diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake
index a78671296c..baf2d8c344 100644
--- a/cmake/CPM.cmake
+++ b/cmake/CPM.cmake
@@ -2,8 +2,8 @@
 #
 # SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors
 
-set(CPM_DOWNLOAD_VERSION 0.38.5)
-set(CPM_HASH_SUM "192aa0ccdc57dfe75bd9e4b176bf7fb5692fd2b3e3f7b09c74856fc39572b31c")
+set(CPM_DOWNLOAD_VERSION 0.40.2)
+set(CPM_HASH_SUM "c8cdc32c03816538ce22781ed72964dc864b2a34a310d3b7104812a5ca2d835d")
 
 if(CPM_SOURCE_CACHE)
   set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
diff --git a/include/accelerator.h b/include/accelerator.h
index 2711859663..a1113c57be 100644
--- a/include/accelerator.h
+++ b/include/accelerator.h
@@ -46,7 +46,12 @@ namespace quda
      * @param out Solution vector.
      * @param in Right-hand side.
      */
-    virtual void operator()(ColorSpinorField &out, ColorSpinorField &in)
+    virtual void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override
+    {
+      for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]);
+    }
+
+    void operator()(ColorSpinorField &out, const ColorSpinorField &in)
     {
       if (transformer.trained) {
         transformer.apply(*base_solver, out, in);
@@ -64,7 +69,7 @@ namespace quda
      * @param null Solver to solve for null vectors.
      * @param in meta color spinor field.
      */
-    virtual void train_param(Solver &null, ColorSpinorField &in)
+    virtual void train_param(Solver &null, const ColorSpinorField &in) override
     {
       if (!active_training && !transformer.trained) {
        active_training = true;
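
Note on the pattern introduced above: the batched operator() is now the virtual entry point, and solvers that are not yet multi-RHS aware satisfy it by looping over a single-field overload. A minimal standalone sketch of this arrangement, where Field, SolverBase and std::span are illustrative stand-ins for ColorSpinorField, Solver and cvector_ref (not QUDA types):

    #include <cstddef>
    #include <span>

    struct Field { /* field payload elided */ };

    struct SolverBase {
      virtual ~SolverBase() = default;

      // batched entry point: the default implementation is a naive loop over sources
      virtual void operator()(std::span<Field> out, std::span<const Field> in)
      {
        for (std::size_t i = 0; i < in.size(); i++) (*this)(out[i], in[i]);
      }

      // per-source implementation supplied by the concrete solver
      virtual void operator()(Field &out, const Field &in) = 0;
    };

A truly multi-RHS-aware solver then overrides the span overload directly and fuses work across sources, while legacy solvers keep working unchanged.
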
diff --git a/include/blas_quda.h b/include/blas_quda.h
index b7da601e38..77b3c404f3 100644
--- a/include/blas_quda.h
+++ b/include/blas_quda.h
@@ -30,6 +30,7 @@ namespace quda {
     inline void copy(cvector_ref<ColorSpinorField> &dst, cvector_ref<const ColorSpinorField> &src)
     {
+      if (dst.size() != src.size()) errorQuda("Mismatched vector sets %lu != %lu", dst.size(), src.size());
       for (auto i = 0u; i < src.size(); i++) { dst[i].copy(src[i]); }
     }
 
@@ -293,7 +294,7 @@ namespace quda {
     inline array<double, 2> max_deviation(const ColorSpinorField &x, const ColorSpinorField &y)
     {
-      return max_deviation(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(y));
+      return max_deviation(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(y))[0];
     }
 
     /**
@@ -302,13 +303,15 @@ namespace quda {
     */
     cvector<double> norm1(cvector_ref<const ColorSpinorField> &x);
 
+    inline double norm1(const ColorSpinorField &x) { return norm1(cvector_ref<const ColorSpinorField> {x})[0]; }
+
     /**
        @brief Compute the L2 norm (||x||^2) of a field
        @param[in] x The field we are reducing
     */
     cvector<double> norm2(cvector_ref<const ColorSpinorField> &x);
 
-    inline double norm2(const ColorSpinorField &x) { return norm2(cvector_ref<const ColorSpinorField> {x}); }
+    inline double norm2(const ColorSpinorField &x) { return norm2(cvector_ref<const ColorSpinorField> {x})[0]; }
 
     /**
        @brief Compute y += a * x and then (x, y)
       @param[in] a scalar multiplier
       @param[in] x input vector
       @param[in] y update vector
     */
@@ -319,6 +322,11 @@ namespace quda {
     cvector<double> axpyReDot(cvector<double> &a, cvector_ref<const ColorSpinorField> &x,
                               cvector_ref<ColorSpinorField> &y);
 
+    inline double axpyReDot(double a, const ColorSpinorField &x, ColorSpinorField &y)
+    {
+      return axpyReDot(cvector<double>(a), cvector_ref<const ColorSpinorField>(x), y)[0];
+    }
+
     /**
        @brief Compute the real-valued inner product (x, y)
        @param[in] x input vector
       @param[in] y input vector
     */
@@ -328,7 +336,7 @@ namespace quda {
     inline double reDotProduct(const ColorSpinorField &x, const ColorSpinorField &y)
     {
-      return reDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y});
+      return reDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
     }
 
     /**
@@ -342,6 +350,12 @@ namespace quda {
     cvector<double> axpbyzNorm(cvector<double> &a, cvector_ref<const ColorSpinorField> &x, cvector<double> &b,
                                cvector_ref<const ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z);
 
+    inline double axpbyzNorm(double a, const ColorSpinorField &x, double b, const ColorSpinorField &y,
+                             ColorSpinorField &z)
+    {
+      return axpbyzNorm(cvector<double>(a), cvector_ref<const ColorSpinorField>(x), cvector<double>(b), y, z)[0];
+    }
+
     /**
        @brief Compute y += a * x and then ||y||^2
        @param[in] a scalar multiplier
@@ -354,6 +368,11 @@ namespace quda {
       return axpbyzNorm(a, x, 1.0, y, y);
     }
 
+    inline double axpyNorm(double a, const ColorSpinorField &x, ColorSpinorField &y)
+    {
+      return axpyNorm(a, cvector_ref<const ColorSpinorField> {x}, cvector_ref<ColorSpinorField> {y})[0];
+    }
+
     /**
        @brief Compute the complex-valued inner product (x, y)
        @param[in] x input vector
       @param[in] y input vector
     */
@@ -363,7 +382,7 @@ namespace quda {
     inline Complex cDotProduct(const ColorSpinorField &x, const ColorSpinorField &y)
     {
-      return cDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y});
+      return cDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
     }
 
     /**
@@ -373,6 +392,11 @@ namespace quda {
     */
     cvector<double4> cDotProductNormAB(cvector_ref<const ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &y);
 
+    inline double4 cDotProductNormAB(const ColorSpinorField &x, const ColorSpinorField &y)
+    {
+      return cDotProductNormAB(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
+    }
+
     /**
        @brief Return complex-valued inner product (x,y) and ||x||^2
        @param[in] x input vector
       @param[in] y input vector
     */
@@ -387,6 +411,11 @@ namespace quda {
       return a;
     }
 
+    inline double3 cDotProductNormA(const ColorSpinorField &x, const ColorSpinorField &y)
+    {
+      return cDotProductNormA(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
+    }
+
     /**
        @brief Return complex-valued inner product (x,y) and ||y||^2
        @param[in] x input vector
       @param[in] y input vector
     */
@@ -401,6 +430,11 @@ namespace quda {
       return a;
     }
 
+    inline double3 cDotProductNormB(const ColorSpinorField &x, const ColorSpinorField &y)
+    {
+      return cDotProductNormB(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
+    }
+
     /**
        @brief Apply the operation z += a * x + b * y, y -= b * w, compute complex-valued
       inner product (u, y) and ||y||^2
     */
@@ -418,6 +452,14 @@ namespace quda {
                                                     cvector_ref<const ColorSpinorField> &w,
                                                     cvector_ref<const ColorSpinorField> &u);
 
+    inline double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, const ColorSpinorField &x, const Complex &b,
+                                                  ColorSpinorField &y, ColorSpinorField &z, const ColorSpinorField &w,
+                                                  const ColorSpinorField &u)
+    {
+      return caxpbypzYmbwcDotProductUYNormY(cvector<Complex>(a), cvector_ref<ColorSpinorField>(x), b, y, z, w,
+                                            u)[0];
+    }
+
     /**
        @brief Compute y = a * x + b * y and then ||y||^2
        @param[in] a scalar multiplier
@@ -440,6 +482,11 @@ namespace quda {
       return caxpbyNorm(a, x, 1.0, y);
     }
 
+    inline double caxpyNorm(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y)
+    {
+      return caxpyNorm(a, cvector_ref<const ColorSpinorField> {x}, cvector_ref<ColorSpinorField> {y})[0];
+    }
+
     /**
        @brief Compute y -= x and then ||y||^2
        @param[in] x input vector
@@ -450,6 +497,11 @@ namespace quda {
       return caxpbyNorm(1.0, x, -1.0, y);
     }
 
+    inline double xmyNorm(const ColorSpinorField &x, ColorSpinorField &y)
+    {
+      return xmyNorm(cvector_ref<const ColorSpinorField> {x}, cvector_ref<ColorSpinorField> {y})[0];
+    }
+
     /**
        @brief Compute z = a * b * x + y, x = a * x, and then ||z||^2
        @param[in] a scalar multiplier
@@ -461,6 +513,12 @@ namespace quda {
     cvector<double> cabxpyzAxNorm(cvector<double> &a, cvector<Complex> &b, cvector_ref<ColorSpinorField> &x,
                                   cvector_ref<const ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z);
 
+    inline double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, const ColorSpinorField &y,
+                                ColorSpinorField &z)
+    {
+      return cabxpyzAxNorm(cvector<double>(a), cvector<Complex>(b), cvector_ref<ColorSpinorField>(x), y, z)[0];
+    }
+
     /**
        @brief Compute y += a * x and the resulting complex-valued inner product (z, y)
        @param[in] a scalar multiplier
@@ -471,6 +529,11 @@ namespace quda {
     cvector<Complex> caxpyDotzy(cvector<Complex> &a, cvector_ref<const ColorSpinorField> &x,
                                 cvector_ref<ColorSpinorField> &y, cvector_ref<const ColorSpinorField> &z);
 
+    inline Complex caxpyDotzy(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z)
+    {
+      return caxpyDotzy(cvector<Complex>(a), cvector_ref<const ColorSpinorField>(x), y, z)[0];
+    }
+
     /**
        @brief Compute y += a * x and then compute ||y||^2 and real-valued
       inner product (y_out, y_out-y_in)
     */
@@ -481,6 +544,11 @@ namespace quda {
     cvector<double2> axpyCGNorm(cvector<double> &a, cvector_ref<const ColorSpinorField> &x,
                                 cvector_ref<ColorSpinorField> &y);
 
+    inline double2 axpyCGNorm(double a, const ColorSpinorField &x, ColorSpinorField &y)
+    {
+      return axpyCGNorm(cvector<double>(a), cvector_ref<const ColorSpinorField>(x), y)[0];
+    }
+
     /**
        @brief Computes ||x||^2, ||r||^2 and the MILC/FNAL heavy quark
       residual norm
     */
@@ -492,7 +560,7 @@ namespace quda {
     inline double3 HeavyQuarkResidualNorm(const ColorSpinorField &x, const ColorSpinorField &r)
     {
-      return HeavyQuarkResidualNorm(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(r));
+      return HeavyQuarkResidualNorm(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(r))[0];
     }
 
     /**
@@ -510,7 +578,7 @@ namespace quda {
                                              const ColorSpinorField &r)
     {
       return xpyHeavyQuarkResidualNorm(cvector_ref<const ColorSpinorField>(x), cvector_ref<ColorSpinorField>(y),
-                                       cvector_ref<const ColorSpinorField>(r));
+                                       cvector_ref<const ColorSpinorField>(r))[0];
     }
 
     /**
@@ -522,6 +590,11 @@ namespace quda {
     cvector<double3> tripleCGReduction(cvector_ref<const ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &y,
                                        cvector_ref<const ColorSpinorField> &z);
 
+    inline double3 tripleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z)
+    {
+      return tripleCGReduction(cvector_ref<const ColorSpinorField>(x), y, z)[0];
+    }
+
     /**
        @brief Computes ||x||^2, ||y||^2, the real-valued inner product (y, z), and ||z||^2
        @param[in] x input vector
       @param[in] y input vector
       @param[in] z input vector
     */
@@ -531,6 +604,11 @@ namespace quda {
     cvector<double4> quadrupleCGReduction(cvector_ref<const ColorSpinorField> &x,
                                           cvector_ref<const ColorSpinorField> &y, cvector_ref<const ColorSpinorField> &z);
 
+    inline double4 quadrupleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z)
+    {
+      return quadrupleCGReduction(cvector_ref<const ColorSpinorField>(x), y, z)[0];
+    }
+
     /**
        @brief Computes z = x, w = y, x += a * y, y -= a * v and ||y||^2
        @param[in] a scalar multiplier
     */
@@ -544,6 +622,12 @@ namespace quda {
                                          cvector_ref<ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z,
                                          cvector_ref<ColorSpinorField> &w, cvector_ref<const ColorSpinorField> &v);
 
+    inline double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z,
+                                       ColorSpinorField &w, const ColorSpinorField &v)
+    {
+      return quadrupleCG3InitNorm(cvector<double>(a), cvector_ref<ColorSpinorField>(x), y, z, w, v)[0];
+    }
+
     /**
        @brief Computes x = b * (x + a * y) + (1 - b) * z, y = b * (y + a * v) + (1 - b) * w, z = x_in, w = y_in, and
       ||y||^2
     */
@@ -560,6 +644,12 @@ namespace quda {
                                            cvector_ref<ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z,
                                            cvector_ref<ColorSpinorField> &w, cvector_ref<const ColorSpinorField> &v);
 
+    inline double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y,
+                                         ColorSpinorField &z, ColorSpinorField &w, const ColorSpinorField &v)
+    {
+      return quadrupleCG3UpdateNorm(cvector<double>(a), b, cvector_ref<ColorSpinorField>(x), y, z, w, v)[0];
+    }
+
     namespace block
     {
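
The convention running through this header is that the batched routine is the primary interface, returning one result per field, and the legacy single-field overload wraps its argument in a one-element batch and returns element [0] -- the bug fixed in several hunks above is precisely a missing [0]. A standalone sketch of the convention with illustrative stand-in types (not QUDA's):

    #include <vector>

    struct Field { double data; };

    // batched primitive: one norm per input field
    std::vector<double> norm2(const std::vector<Field> &x)
    {
      std::vector<double> out;
      out.reserve(x.size());
      for (auto &f : x) out.push_back(f.data * f.data);
      return out;
    }

    // single-field convenience overload, forwarding to the batched form
    inline double norm2(const Field &x) { return norm2(std::vector<Field> {x})[0]; }

Without the trailing [0] the single-field overload would attempt to convert a whole result vector to a scalar, which is exactly the class of error these fixes address.
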
diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h
index bbea1b7842..18de570604 100644
--- a/include/color_spinor_field.h
+++ b/include/color_spinor_field.h
@@ -943,6 +943,36 @@ namespace quda
   void resize(std::vector<ColorSpinorField> &v, size_t new_size, QudaFieldCreate create,
               const ColorSpinorField &src = ColorSpinorField());
 
+  /**
+     @brief Create a vector of fields that aliases another vector of
+     fields' storage.  The alias fields can use a different precision
+     than the fields being aliased, though it cannot be greater.  This
+     functionality is useful for the case where we have multiple
+     temporaries in different precisions, but do not need them
+     simultaneously.  Use this functionality with caution.
+     @param[out] alias The vector of aliased fields
+     @param[in] v The vector of fields to alias
+     @param[in] param Parameters for the alias field
+  */
+  void create_alias(cvector_ref<ColorSpinorField> &alias, cvector_ref<const ColorSpinorField> &v,
+                    const ColorSpinorParam &param = ColorSpinorParam());
+
+  /**
+     @brief Create a vector of fields that aliases another vector of
+     fields' storage.  The alias fields can use a different precision
+     than the fields being aliased, though it cannot be greater.  This
+     functionality is useful for the case where we have multiple
+     temporaries in different precisions, but do not need them
+     simultaneously.  This variant is used with std::vector as opposed
+     to vector_ref, and allows for correct resizing.  Use this
+     functionality with caution.
+     @param[out] alias The vector of aliased fields
+     @param[in] v The vector of fields to alias
+     @param[in] param Parameters for the alias field
+  */
+  void create_alias(std::vector<ColorSpinorField> &alias, cvector_ref<const ColorSpinorField> &v,
+                    const ColorSpinorParam &param = ColorSpinorParam());
+
   void copyGenericColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src, QudaFieldLocation location,
                               void *Dst = nullptr, const void *Src = nullptr);
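
The "use with caution" warning deserves emphasis: the alias shares the underlying allocation, so the aliased and aliasing fields must never be live at the same time. A standalone sketch of the idea with hypothetical types (FieldHi/FieldLo are illustrative, not QUDA's):

    #include <cstddef>
    #include <memory>
    #include <vector>

    struct FieldHi {
      std::shared_ptr<std::vector<std::byte>> storage;
      explicit FieldHi(std::size_t bytes) : storage(std::make_shared<std::vector<std::byte>>(bytes)) { }
    };

    struct FieldLo {
      std::shared_ptr<std::vector<std::byte>> storage; // a view onto FieldHi's buffer
    };

    // create a lower-precision alias for every field in v: no new allocation,
    // only valid while the originals are not in use
    std::vector<FieldLo> create_alias_sketch(const std::vector<FieldHi> &v)
    {
      std::vector<FieldLo> alias(v.size());
      for (std::size_t i = 0; i < v.size(); i++) alias[i].storage = v[i].storage;
      return alias;
    }

The lower precision works because a lower-precision field never needs more bytes than the higher-precision field it overlays.
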
diff --git a/include/dirac_quda.h b/include/dirac_quda.h
index a5ba6e0966..a14daef015 100644
--- a/include/dirac_quda.h
+++ b/include/dirac_quda.h
@@ -542,7 +542,6 @@ namespace quda {
     DiracWilson(const DiracWilson &dirac);
     DiracWilson(const DiracParam &param, const int nDims); // to correctly adjust face for DW and non-deg twisted mass
-    virtual ~DiracWilson();
     DiracWilson& operator=(const DiracWilson &dirac);
 
     virtual void Dslash(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
@@ -592,7 +591,6 @@ namespace quda {
   public:
     DiracWilsonPC(const DiracParam &param);
     DiracWilsonPC(const DiracWilsonPC &dirac);
-    virtual ~DiracWilsonPC();
     DiracWilsonPC& operator=(const DiracWilsonPC &dirac);
 
     void M(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override;
@@ -1962,7 +1960,6 @@ namespace quda {
       @param[in] param Parameters defining this operator
     */
     DiracCoarse(const DiracCoarse &dirac, const DiracParam &param);
-    virtual ~DiracCoarse();
 
     virtual bool isCoarse() const override { return true; }
 
@@ -2108,8 +2105,6 @@ namespace quda {
     */
     DiracCoarsePC(const DiracCoarse &dirac, const DiracParam &param);
 
-    virtual ~DiracCoarsePC();
-
     /**
       @brief Apply preconditioned Dslash
       out = (D * in)
       @param[out] out Output field
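
Deleting these empty destructors is safe because a destructor is implicitly virtual whenever any base class declares a virtual destructor, so the empty overrides added nothing. A minimal illustration with hypothetical classes (not QUDA's):

    #include <memory>

    struct Base {
      virtual ~Base() = default; // one virtual destructor in the root class suffices
    };

    struct Derived : Base {
      // no user-declared destructor needed: ~Derived() is implicitly virtual
    };

    int main()
    {
      std::unique_ptr<Base> p = std::make_unique<Derived>(); // destroys Derived correctly
    }

Dropping the user-declared destructors also re-enables implicitly generated move operations where the remaining members permit them.
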
diff --git a/include/dslash.h b/include/dslash.h
index 2643fbbc4e..4b1aada52e 100644
--- a/include/dslash.h
+++ b/include/dslash.h
@@ -156,8 +156,8 @@ namespace quda
       }
     }
 
-    virtual int blockStep() const override { return 16; }
-    virtual int blockMin() const override { return 16; }
+    virtual int blockStep() const override { return (arg.shmem & 64) ? 8 : 16; }
+    virtual int blockMin() const override { return (arg.shmem & 64) ? 8 : 16; }
 
     unsigned int maxSharedBytesPerBlock() const override { return maxDynamicSharedBytesPerBlock(); }
 
@@ -207,13 +207,12 @@ namespace quda
     virtual void initTuneParam(TuneParam &param) const override
     {
-      /* for nvshmem uber kernels the current synchronization requires use to keep the y and z dimension local to the
+      /* for nvshmem uber kernels the current synchronization requires us to keep the y and z dimension local to the
        * block. This can be removed when we introduce a finer grained synchronization which takes into account the y and
        * z components explicitly */
-      if (arg.shmem & 64) {
-        step_y = vector_length_y;
-        step_z = vector_length_z;
-      }
+      step_y = arg.shmem & 64 ? vector_length_y : step_y_bkup;
+      step_z = arg.shmem & 64 ? vector_length_z : step_z_bkup;
+
       TunableKernel3D::initTuneParam(param);
       if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL))
         param.aux.x = 1; // packing blocks per direction
@@ -225,10 +224,9 @@ namespace quda
       /* for nvshmem uber kernels the current synchronization requires us to keep the y and z dimension local to the
        * block. This can be removed when we introduce a finer grained synchronization which takes into account the y and
        * z components explicitly. */
-      if (arg.shmem & 64) {
-        step_y = vector_length_y;
-        step_z = vector_length_z;
-      }
+      step_y = arg.shmem & 64 ? vector_length_y : step_y_bkup;
+      step_z = arg.shmem & 64 ? vector_length_z : step_z_bkup;
+
       TunableKernel3D::defaultTuneParam(param);
       if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL))
         param.aux.x = 1; // packing blocks per direction
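
The fix above addresses a subtle statefulness bug: the old code overwrote step_y/step_z when the uber-kernel flag (bit 64 of shmem) was set but never restored them afterwards, so a later tuning of the same object saw stale values. Selecting between the live value and a backup on every call makes the method idempotent. A standalone sketch, assuming step_y_bkup/step_z_bkup hold the values captured at construction (their definitions sit outside this hunk):

    struct Tunable3DSketch {
      unsigned step_y_bkup = 1, step_z_bkup = 1;       // captured at construction
      unsigned vector_length_y = 8, vector_length_z = 4;
      unsigned step_y = 1, step_z = 1;

      void initTuneParam(bool uber_kernel)
      {
        // select rather than conditionally overwrite, so repeated calls are consistent
        step_y = uber_kernel ? vector_length_y : step_y_bkup;
        step_z = uber_kernel ? vector_length_z : step_z_bkup;
      }
    };
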
diff --git a/include/eigen_helper.h b/include/eigen_helper.h
index f4d91d74b5..8c1bf012a8 100644
--- a/include/eigen_helper.h
+++ b/include/eigen_helper.h
@@ -10,8 +10,19 @@
 #endif
 
 #include <Eigen/Core>
+
+// hide annoying warning
+#if !defined(__clang__) && !defined(_NVHPC_CUDA)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#endif
+
 #include <Eigen/Dense>
 #include <Eigen/Eigenvalues>
 #include <Eigen/SVD>
 
+#if !defined(__clang__) && !defined(_NVHPC_CUDA)
+#pragma GCC diagnostic pop
+#endif
+
 using namespace Eigen;
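
The push/pop pair scopes the diagnostic change to the offending headers only, so the rest of the translation unit keeps -Wmaybe-uninitialized; the guards restrict the GCC-specific pragma to compilers that emit this warning. The same pattern in a self-contained form (the include here is a stand-in for whatever header triggers the warning):

    #if !defined(__clang__) && !defined(_NVHPC_CUDA)
    #pragma GCC diagnostic push
    #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
    #endif

    #include <vector> // stand-in for the warning-triggering third-party header

    #if !defined(__clang__) && !defined(_NVHPC_CUDA)
    #pragma GCC diagnostic pop
    #endif
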
diff --git a/include/invert_quda.h b/include/invert_quda.h
index 816313fbd4..d9d5279066 100644
--- a/include/invert_quda.h
+++ b/include/invert_quda.h
@@ -14,19 +14,6 @@
 namespace quda {
 
-// temporary addition until multi-RHS for all Dirac operator functions
-#ifdef __CUDACC__
-#ifdef __NVCC_DIAG_PRAGMA_SUPPORT__
-#pragma nv_diag_suppress 611
-#else
-#pragma diag_suppress 611
-#endif
-#endif
-
-#ifdef __NVCOMPILER
-#pragma diag_suppress partial_override
-#endif
-
   /**
      SolverParam is the meta data used to define linear solvers.
   */
@@ -111,7 +98,7 @@ namespace quda {
     int max_hq_res_increase = 0;
 
     /**< This parameter determines how many total heavy-quark residual
-       restarts we tolerate before terminating the solver, i.e., how long 
+       restarts we tolerate before terminating the solver, i.e., how long
        do we want to keep trying to converge */
     int max_hq_res_restart_total = 0;
 
@@ -137,10 +124,10 @@ namespace quda {
     bool sloppy_converge = false;
 
     /**< Actual L2 residual norm achieved in solver */
-    double true_res = 0.0;
+    vector<double> true_res;
 
     /**< Actual heavy quark residual norm achieved in solver */
-    double true_res_hq = 0.0;
+    vector<double> true_res_hq;
 
     /**< Maximum number of iterations in the linear solver */
     int maxiter = 0;
@@ -305,8 +292,8 @@ namespace quda {
       tol_restart(param.tol_restart),
       tol_hq(param.tol_hq),
       compute_true_res(param.compute_true_res),
-      true_res(param.true_res),
-      true_res_hq(param.true_res_hq),
+      true_res(param.num_src, 0.0),
+      true_res_hq(param.num_src, 0.0),
       maxiter(param.maxiter),
       iter(param.iter),
       precision(param.cuda_prec),
@@ -368,35 +355,12 @@ namespace quda {
     SolverParam(const SolverParam &param) = default;
 
     /**
-       Update the QudaInvertParam with the data from this
-       @param param the QudaInvertParam to be updated
+       @brief Update the QudaInvertParam with the data from this
+       instance (update the true residuals, and other observables).
+       @param[in,out] param the QudaInvertParam to be updated
+       @param[in] offset offset into the shifted-residual arrays to update (if negative, all shifts are updated)
     */
-    void updateInvertParam(QudaInvertParam &param, int offset=-1) {
-      param.true_res = true_res;
-      param.true_res_hq = true_res_hq;
-      param.iter += iter;
-      if (offset >= 0) {
-        param.true_res_offset[offset] = true_res_offset[offset];
-        param.iter_res_offset[offset] = iter_res_offset[offset];
-        param.true_res_hq_offset[offset] = true_res_hq_offset[offset];
-      } else {
-        for (int i=0; i<num_offset; i++) {
-          param.true_res_offset[i] = true_res_offset[i];
-          param.iter_res_offset[i] = iter_res_offset[i];
-          param.true_res_hq_offset[i] = true_res_hq_offset[i];
-        }
-      }
-
-      if (param.eig_param) { *static_cast<QudaEigParam *>(param.eig_param) = eig_param; }
-    }
+    void updateInvertParam(QudaInvertParam &param, int offset = -1);
 
     // for incremental eigCG:
     void updateRhsIndex(QudaInvertParam &param) { rhs_idx = param.rhs_idx; }
@@ -422,31 +386,38 @@ namespace quda {
 
     bool mixed() { return param.precision != param.precision_sloppy; }
 
+    /**
+       @brief Check the support of each source field, and return true
+       if all fields in the set have zero support.  If we are doing
+       null-space finding, this function always returns false.  If a
+       given source vector does have zero support, then the matching
+       solution vector is set to zero to match.
+       @param[in] x Solution vector set
+       @param[in] b Source vector set
+       @param[in] b2 Vector of norms
+       @return Whether all sources have zero support
+    */
+    bool is_zero_src(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b, cvector<double> &b2);
+
   public:
     Solver(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
            const DiracMatrix &matEig, SolverParam &param);
     virtual ~Solver();
 
-    virtual void operator()(ColorSpinorField &out, ColorSpinorField &in) = 0;
-
-    /**
-      @brief Naive loop over RHS, for solvers that are not yet multi-RHS aware
-    */
-    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<ColorSpinorField> &in)
-    {
-      for (auto i = 0u; i < in.size(); i++) { this->operator()(out[i], in[i]); }
-    }
+    virtual void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) = 0;
 
     virtual void blocksolve(ColorSpinorField &out, ColorSpinorField &in);
 
     /**
        @return Return the residual vector from the prior solve
     */
-    virtual ColorSpinorField &get_residual()
+    virtual cvector_ref<const ColorSpinorField> get_residual()
     {
       errorQuda("Not implemented");
-      static ColorSpinorField dummy;
-      return dummy;
+      return cvector_ref<const ColorSpinorField>();
     }
 
     /**
@@ -455,7 +426,7 @@ namespace quda {
       @param Solver the solver to be used to collect the null space vectors.
       @param ColorSpinorField the vector used to perform the training.
     */
-    virtual void train_param(Solver &, ColorSpinorField &)
+    virtual void train_param(Solver &, const ColorSpinorField &)
     {
       // Do nothing
     }
 
    /**
       @brief a virtual method that performs the inversion and collects
       some vectors.  The default here is a no-op and should not be called.
@@ -464,9 +435,10 @@ namespace quda {
    */
-    virtual void solve_and_collect(ColorSpinorField &, ColorSpinorField &, cvector_ref<ColorSpinorField> &, int, double)
+    virtual void solve_and_collect(cvector_ref<ColorSpinorField> &, cvector_ref<const ColorSpinorField> &,
+                                   cvector_ref<ColorSpinorField> &, int, double)
     {
-      errorQuda("NOT implemented.");
+      errorQuda("Not implemented.");
     }
 
     void set_tol(double tol) { param.tol = tol; }
@@ -492,7 +464,7 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
     /**
       @brief Solver factory
    */
@@ -546,17 +518,22 @@ namespace quda {
       @param[in] residual_type The type of residual we want to solve for
       @return L2 stopping condition
     */
-    static double stopping(double tol, double b2, QudaResidualType residual_type);
+    static vector<double> stopping(double tol, cvector<double> &b2, QudaResidualType residual_type);
+
+    static inline double stopping(double tol, double b2, QudaResidualType residual_type)
+    {
+      return stopping(tol, cvector<double>(b2), residual_type)[0];
+    }
 
     /**
-       @briefTest for solver convergence
+       @brief Test for solver convergence
       @param[in] r2 L2 norm squared of the residual
       @param[in] hq2 Heavy quark residual
       @param[in] r2_tol Solver L2 tolerance
       @param[in] hq_tol Solver heavy-quark tolerance
       @return Whether converged
     */
-    bool convergence(double r2, double hq2, double r2_tol, double hq_tol);
+    bool convergence(cvector<double> &r2, cvector<double> &hq2, cvector<double> &r2_tol, cvector<double> &hq_tol);
 
     /**
       @brief Test for HQ solver convergence -- ignore L2 residual
       @param[in] hq2 Heavy quark residual
       @param[in] hq_tol Solver heavy-quark tolerance
       @return Whether converged
    */
@@ -566,7 +543,7 @@ namespace quda {
-    bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol);
+    bool convergenceHQ(cvector<double> &hq2, cvector<double> &hq_tol);
 
     /**
       @brief Test for L2 solver convergence -- ignore HQ residual
       @param[in] r2 L2 norm squared of the residual
       @param[in] r2_tol Solver L2 tolerance
    */
@@ -575,7 +552,7 @@ namespace quda {
-    bool convergenceL2(double r2, double hq2, double r2_tol, double hq_tol);
+    bool convergenceL2(cvector<double> &r2, cvector<double> &r2_tol);
 
     /**
       @brief Prints out the running statistics of the solver
       @param[in] name Name of solver that called this
       @param[in] k iteration count
       @param[in] r2 L2 norm squared of the residual
       @param[in] hq2 Heavy quark residual
    */
@@ -585,7 +562,7 @@ namespace quda {
-    void PrintStats(const char *name, int k, double r2, double b2, double hq2);
+    void PrintStats(const char *name, int k, cvector<double> &r2, cvector<double> &b2, cvector<double> &hq2 = {});
 
     /**
       @brief Prints out the summary of the solver convergence
       @param[in] r2_tol Solver L2 tolerance
       @param[in] hq_tol Solver heavy-quark tolerance
    */
@@ -598,7 +575,8 @@ namespace quda {
-    void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol);
+    void PrintSummary(const char *name, int k, cvector<double> &r2, cvector<double> &b2, cvector<double> &r2_tol,
+                      cvector<double> &hq_tol = {});
 
     /**
       @brief Returns the epsilon tolerance for a given precision, by default returns
      the solver precision.
    */
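
With one tolerance check per right-hand side, the set converges only when every source meets its own threshold. A sketch of the vectorized logic implied by the new signatures (illustrative code, not QUDA's implementation; the relative L2 stopping condition b2 * tol^2 is the standard one the scalar version used):

    #include <cstddef>
    #include <vector>

    std::vector<double> stopping_sketch(double tol, const std::vector<double> &b2)
    {
      std::vector<double> stop(b2.size());
      for (std::size_t i = 0; i < b2.size(); i++) stop[i] = b2[i] * tol * tol; // per-source L2 threshold
      return stop;
    }

    bool convergence_sketch(const std::vector<double> &r2, const std::vector<double> &r2_tol)
    {
      for (std::size_t i = 0; i < r2.size(); i++)
        if (r2[i] > r2_tol[i]) return false; // one unconverged source blocks the whole set
      return true;
    }
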
@@ -667,10 +645,8 @@ namespace quda {
       @brief Compute power iterations on a Dirac matrix
       @param[in] diracm Dirac matrix used for power iterations
       @param[in] start Starting rhs for power iterations; value preserved unless it aliases tempvec1 or tempvec2
-      @param[in,out] tempvec1 Temporary vector used for power iterations (FIXME: can become a reference when std::swap
-      can be used on ColorSpinorField)
-      @param[in,out] tempvec2 Temporary vector used for power iterations (FIXME: can become a reference when std::swap
-      can be used on ColorSpinorField)
+      @param[in,out] tempvec1 Temporary vector used for power iterations
+      @param[in,out] tempvec2 Temporary vector used for power iterations
       @param[in] niter Total number of power iteration iterations
       @param[in] normalize_freq Frequency with which intermediate vector gets normalized
       @param[in] args Parameter pack of ColorSpinorFields used as temporary passed to Dirac
    */
@@ -684,8 +660,8 @@ namespace quda {
     /**
       @brief Generate a Krylov space in a given basis
       @param[in] diracm Dirac matrix used to generate the Krylov space
-      @param[out] Ap dirac matrix times the Krylov basis vectors
-      @param[in,out] p Krylov basis vectors; assumes p[0] is in place
+      @param[out] Ap dirac matrix times the Krylov basis vector sets
+      @param[in,out] p Krylov basis vector sets; assumes p[0] is in place
       @param[in] n_krylov Size of krylov space
       @param[in] basis Basis type
       @param[in] m_map Slope mapping for Chebyshev basis; ignored for power basis
       @param[in] b_map Intercept mapping for Chebyshev basis; ignored for power basis
       @param[in] args Parameter pack of ColorSpinorFields used as temporary passed to Dirac
     */
     template <typename... Args>
-    static void computeCAKrylovSpace(const DiracMatrix &diracm, std::vector<ColorSpinorField> &Ap,
-                                     std::vector<ColorSpinorField> &p, int n_krylov, QudaCABasis basis, double m_map,
-                                     double b_map, Args &&...args);
+    static void computeCAKrylovSpace(const DiracMatrix &diracm, std::vector<std::vector<ColorSpinorField>> &Ap,
+                                     std::vector<std::vector<ColorSpinorField>> &p, int n_krylov, QudaCABasis basis,
+                                     double m_map, double b_map, Args &&...args);
+
+    // FIXME delete this variant once CA-CG is MRHS aware
+    template <typename... Args>
+    void computeCAKrylovSpace(const DiracMatrix &diracm, std::vector<ColorSpinorField> &Ap,
+                              std::vector<ColorSpinorField> &p, int n_krylov, QudaCABasis basis, double m_map,
+                              double b_map, Args &&...args)
+    {
+      std::vector<std::vector<ColorSpinorField>> p2(p.size());
+      for (auto i = 0u; i < p.size(); i++) {
+        p2[i].resize(1);
+        p2[i][0] = std::move(p[i]);
+      }
+      std::vector<std::vector<ColorSpinorField>> Ap2(Ap.size());
+      for (auto i = 0u; i < Ap.size(); i++) {
+        Ap2[i].resize(1);
+        Ap2[i][0] = std::move(Ap[i]);
+      }
+
+      computeCAKrylovSpace(diracm, Ap2, p2, n_krylov, basis, m_map, b_map, args...);
+
+      for (auto i = 0u; i < p.size(); i++) p[i] = std::move(p2[i][0]);
+      for (auto i = 0u; i < Ap.size(); i++) Ap[i] = std::move(Ap2[i][0]);
+    }
   };
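
The shim above relies on moving fields into singleton inner vectors and back, which is cheap for move-enabled types and avoids copying field data. The generic form of the trick, as a self-contained sketch:

    #include <cstddef>
    #include <utility>
    #include <vector>

    // wrap each element of a flat vector in a size-1 inner vector by moving,
    // call a set-based routine, then move the elements back into place
    template <typename T, typename F>
    void call_with_singleton_sets(std::vector<T> &v, F &&f)
    {
      std::vector<std::vector<T>> sets(v.size());
      for (std::size_t i = 0; i < v.size(); i++) sets[i].push_back(std::move(v[i]));

      f(sets); // f operates on std::vector<std::vector<T>>

      for (std::size_t i = 0; i < v.size(); i++) v[i] = std::move(sets[i][0]);
    }
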
 
   /**
      @brief Conjugate-Gradient Solver.
   */
   class CG : public Solver
   {
   private:
@@ -704,13 +703,13 @@ namespace quda {
-    ColorSpinorField y;
-    ColorSpinorField r;
-    ColorSpinorField rnew;
-    ColorSpinorField p;
-    ColorSpinorField Ap;
-    ColorSpinorField rSloppy;
-    ColorSpinorField xSloppy;
+    std::vector<ColorSpinorField> y;
+    std::vector<ColorSpinorField> r;
+    std::vector<ColorSpinorField> rnew;
+    std::vector<ColorSpinorField> p;
+    std::vector<ColorSpinorField> Ap;
+    std::vector<ColorSpinorField> r_sloppy;
+    std::vector<ColorSpinorField> x_sloppy;
     bool init = false;
 
     /**
@@ -718,18 +717,23 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
        SolverParam &param);
     virtual ~CG();
+
     /**
      * @brief Run CG.
      * @param out Solution vector.
      * @param in Right-hand side.
      */
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override { (*this)(out, in, nullptr, 0.0); };
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override
+    {
+      std::vector<ColorSpinorField> tmp(in.size());
+      operator()(out, in, tmp, vector<double>(in.size(), 0.0));
+    }
 
     /**
      * @brief Solve re-using an initial Krylov space defined by an initial r2_old_init and search direction p_init.
      * @param out Solution vector.
      * @param in Right-hand side.
      * @param p_init Initial-search direction.
      * @param r2_old_init [description]
      */
@@ -739,14 +743,15 @@ namespace quda {
-    void operator()(ColorSpinorField &out, ColorSpinorField &in, ColorSpinorField *p_init, double r2_old_init);
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
+                    cvector_ref<ColorSpinorField> &p_init, cvector<double> &r2_old_init);
 
     void blocksolve(ColorSpinorField &out, ColorSpinorField &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return true; } /** CG is only for Hermitian systems */
 
     virtual QudaInverterType getInverterType() const override { return QUDA_CG_INVERTER; }
 
     /**
@@ -758,19 +763,19 @@ namespace quda {
      * @param out Solution-vector.
      * @param in Right-hand side.
      */
-    void hqsolve(ColorSpinorField &out, ColorSpinorField &in);
+    void hqsolve(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in);
   };
 
-  class CGNE : public CG
+  class CGNE : public Solver
   {
-  private:
     DiracMMdag mmdag;
     DiracMMdag mmdagSloppy;
     DiracMMdag mmdagPrecon;
     DiracMMdag mmdagEig;
-    ColorSpinorField xp;
-    ColorSpinorField yp;
+    std::unique_ptr<CG> cg;
+    std::vector<ColorSpinorField> xe;
+    std::vector<ColorSpinorField> ye;
     bool init = false;
 
     /**
@@ -778,33 +783,33 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
          const DiracMatrix &matEig, SolverParam &param);
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const final { return false; } /** CGNE is for any system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_CGNE_INVERTER; }
   };
 
-  class CGNR : public CG
+  class CGNR : public Solver
   {
-  private:
     DiracMdagM mdagm;
     DiracMdagM mdagmSloppy;
     DiracMdagM mdagmPrecon;
     DiracMdagM mdagmEig;
-    ColorSpinorField br;
+    std::unique_ptr<CG> cg;
+    std::vector<ColorSpinorField> br;
     bool init = false;
 
     /**
@@ -812,18 +817,18 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
          const DiracMatrix &matEig, SolverParam &param);
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const final { return false; } /** CGNR is for any system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_CGNR_INVERTER; }
   };
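
CGNE and CGNR now own a CG instance (composition) instead of inheriting from CG; the operator types in their members document the normal-equation reductions they implement: CGNE runs CG on (M M^dag) y = b and recovers x = M^dag y, while CGNR runs CG on (M^dag M) x = M^dag b. A hedged functional sketch, with cg_hermitian left as a declaration standing in for a Hermitian positive-definite CG solve:

    #include <functional>
    #include <vector>

    using Vec = std::vector<double>;
    using Op = std::function<Vec(const Vec &)>;

    Vec cg_hermitian(const Op &A, const Vec &rhs); // assumed: solves A z = rhs, A Hermitian positive-definite

    // CGNE: solve (M M^dag) y = b, then x = M^dag y solves M x = b
    Vec solve_cgne(const Op &M, const Op &Mdag, const Vec &b)
    {
      Vec y = cg_hermitian([&](const Vec &v) { return M(Mdag(v)); }, b);
      return Mdag(y);
    }

    // CGNR: solve (M^dag M) x = M^dag b
    Vec solve_cgnr(const Op &M, const Op &Mdag, const Vec &b)
    {
      return cg_hermitian([&](const Vec &v) { return Mdag(M(v)); }, Mdag(b));
    }

Either reduction lets a Hermitian-only solver handle a general (non-Hermitian) system, which is why both classes report hermitian() == false.
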
 
   class CG3 : public Solver
   {
   private:
@@ -834,14 +839,14 @@ namespace quda {
-    ColorSpinorField y;
-    ColorSpinorField r;
-    ColorSpinorField tmp;
-    ColorSpinorField ArS;
-    ColorSpinorField rS;
-    ColorSpinorField xS;
-    ColorSpinorField xS_old;
-    ColorSpinorField rS_old;
+    std::vector<ColorSpinorField> y;
+    std::vector<ColorSpinorField> r;
+    std::vector<ColorSpinorField> tmp;
+    std::vector<ColorSpinorField> ArS;
+    std::vector<ColorSpinorField> rS;
+    std::vector<ColorSpinorField> xS;
+    std::vector<ColorSpinorField> xS_old;
+    std::vector<ColorSpinorField> rS_old;
     bool init = false;
 
     /**
@@ -849,103 +854,36 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
         SolverParam &param);
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return true; } /** CG is only for Hermitian systems */
 
     virtual QudaInverterType getInverterType() const override { return QUDA_CG3_INVERTER; }
   };
 
-  class CG3NE : public CG3
-  {
-
-  private:
-    DiracMMdag mmdag;
-    DiracMMdag mmdagSloppy;
-    DiracMMdag mmdagPrecon;
-    ColorSpinorField xp;
-    ColorSpinorField yp;
-    bool init = false;
-
-    /**
-      @brief Initiate the fields needed by the solver
-      @param[in] x Solution vector
-      @param[in] b Source vector
-    */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
-
-  public:
-    CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param);
-
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
-
-    /**
-      @return Return the residual vector from the prior solve
-    */
-    ColorSpinorField &get_residual() override;
-
-    virtual bool hermitian() const final { return false; } /** CG3NE is for any system */
-
-    virtual QudaInverterType getInverterType() const final { return QUDA_CG3NE_INVERTER; }
-  };
-
-  class CG3NR : public CG3
+  class PCG : public Solver
   {
-
-  private:
-    DiracMdagM mdagm;
-    DiracMdagM mdagmSloppy;
-    DiracMdagM mdagmPrecon;
-    ColorSpinorField br;
-    bool init = false;
-
-    /**
-      @brief Initiate the fields needed by the solver
-      @param[in] x Solution vector
-      @param[in] b Source vector
-    */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
-
-  public:
-    CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param);
-
-    void operator()(ColorSpinorField &out, ColorSpinorField &in);
-
-    /**
-      @return Return the residual vector from the prior solve
-    */
-    ColorSpinorField &get_residual();
-
-    virtual bool hermitian() const final { return false; } /** CG3NR is for any system */
-
-    virtual QudaInverterType getInverterType() const final { return QUDA_CG3NR_INVERTER; }
-  };
-
-  class PreconCG : public Solver {
-
   private:
     std::shared_ptr<Solver> K;
     SolverParam Kparam; // parameters for preconditioner solve
 
-    ColorSpinorField r;
-    ColorSpinorField y;
-    ColorSpinorField Ap;
-    ColorSpinorField x_sloppy;
-    ColorSpinorField r_sloppy;
-    ColorSpinorField minvr;
-    ColorSpinorField minvr_sloppy;
-    ColorSpinorField minvr_pre;
-    ColorSpinorField r_pre;
-    XUpdateBatch x_update_batch;
+    std::vector<ColorSpinorField> r;
+    std::vector<ColorSpinorField> y;
+    std::vector<ColorSpinorField> Ap;
+    std::vector<ColorSpinorField> x_sloppy;
+    std::vector<ColorSpinorField> r_sloppy;
+    std::vector<ColorSpinorField> minvr_sloppy;
+    std::vector<ColorSpinorField> minvr_pre;
+    std::vector<ColorSpinorField> r_pre;
 
     int Np; /** the size of the accumulator pipeline */
     bool init = false;
 
     /**
@@ -955,11 +893,11 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
-    PreconCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
-             const DiracMatrix &matEig, SolverParam &param);
+    PCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
+        SolverParam &param);
 
     /**
      * @brief Preconditioned CG supporting a pre-existing preconditioner K.
@@ -970,14 +908,14 @@ namespace quda {
      * @param matEig Deflation precision Dirac matrix
      * @param param Solver parameters
      */
-    PreconCG(const DiracMatrix &mat, Solver &K, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
-             const DiracMatrix &matEig, SolverParam &param);
+    PCG(const DiracMatrix &mat, Solver &K, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
+        const DiracMatrix &matEig, SolverParam &param);
 
-    virtual ~PreconCG();
+    virtual ~PCG();
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override
     {
-      this->solve_and_collect(out, in, cvector_ref<ColorSpinorField>(), 0, 0);
+      solve_and_collect(out, in, {}, 0, 0);
     }
 
     /**
@@ -988,28 +926,27 @@ namespace quda {
      @param collect_miniter minimal iteration start from which the r vectors are to be collected
      @param collect_tol maxiter tolerance start from which the r vectors are to be collected
    */
-    virtual void solve_and_collect(ColorSpinorField &out, ColorSpinorField &in, cvector_ref<ColorSpinorField> &v_r,
-                                   int collect_miniter, double collect_tol) override;
+    virtual void solve_and_collect(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
+                                   cvector_ref<ColorSpinorField> &v_r, int collect_miniter, double collect_tol) override;
 
     virtual bool hermitian() const override { return true; } /** PCG is only Hermitian system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_PCG_INVERTER; }
   };
 
   class BiCGstab : public Solver
   {
   private:
     const DiracMdagM matMdagM; // used by the eigensolver
 
-    ColorSpinorField y;         // Full precision solution accumulator
-    ColorSpinorField r;         // Full precision residual vector
-    ColorSpinorField p;         // Sloppy precision search direction
-    ColorSpinorField v;         // Sloppy precision A * p
-    ColorSpinorField t;         // Sloppy precision vector used for minres step
-    ColorSpinorField r0;        // Bi-orthogonalization vector
-    ColorSpinorField r_sloppy;  // Slopy precision residual vector
-    ColorSpinorField x_sloppy;  // Sloppy solution accumulator vector
+    std::vector<ColorSpinorField> y;         // Full precision solution accumulator
+    std::vector<ColorSpinorField> r;         // Full precision residual vector
+    std::vector<ColorSpinorField> p;         // Sloppy precision search direction
+    std::vector<ColorSpinorField> v;         // Sloppy precision A * p
+    std::vector<ColorSpinorField> t;         // Sloppy precision vector used for minres step
+    std::vector<ColorSpinorField> r0;        // Bi-orthogonalization vector
+    std::vector<ColorSpinorField> r_sloppy;  // Sloppy precision residual vector
+    std::vector<ColorSpinorField> x_sloppy;  // Sloppy solution accumulator vector
     bool init = false;
 
     /**
@@ -1017,19 +954,19 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
              const DiracMatrix &matEig, SolverParam &param);
     virtual ~BiCGstab();
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_BICGSTAB_INVERTER; }
   };
 
  /**
     @brief BiCGstab-L solver.
  */
@@ -1042,7 +979,6 @@ namespace quda {
   class BiCGstabL : public Solver
   {
-  private:
     const DiracMdagM matMdagM; // used by the eigensolver
 
    /**
      ...
    */
    int pipeline; // pipelining factor for legacyGramSchmidt
 
    // Various coefficients and params needed on each iteration.
@@ -1052,41 +988,46 @@ namespace quda {
-    Complex rho0, rho1, alpha, omega, beta;             // Various coefficients for the BiCG part of BiCGstab-L.
-    std::vector<Complex> gamma, gamma_prime, gamma_prime_prime; // Parameters for MR part of BiCGstab-L. (L+1) length.
-    std::vector<Complex> tau;    // Parameters for MR part of BiCGstab-L. Tech. modified Gram-Schmidt coeffs. (L+1)x(L+1) length.
-    std::vector<double> sigma;   // Parameters for MR part of BiCGstab-L. Tech. the normalization part of Gram-Scmidt. (L+1) length.
+    vector<Complex> rho0, rho1, alpha, omega, beta; // Various coefficients for the BiCG part of BiCGstab-L.
+    vector<std::vector<Complex>> gamma, gamma_prime, gamma_prime_prime; // Parameters for MR part of BiCGstab-L. (L+1) length.
+    vector<std::vector<Complex>> tau; // Parameters for MR part of BiCGstab-L. Tech. modified Gram-Schmidt coeffs. (L+1)x(L+1) length.
+    vector<std::vector<double>>
+      sigma; // Parameters for MR part of BiCGstab-L. Tech. the normalization part of Gram-Schmidt. (L+1) length.
 
-    ColorSpinorField r_full; //! Full precision residual.
-    ColorSpinorField y;      //! Full precision temporary.
+    std::vector<ColorSpinorField> r_full; //! Full precision residual.
+    std::vector<ColorSpinorField> y;      //! Full precision temporary.
 
     // sloppy precision fields
-    ColorSpinorField temp;              //! Sloppy temporary vector.
-    std::vector<ColorSpinorField> r;    // Current residual + intermediate residual values, along the MR.
-    std::vector<ColorSpinorField> u;    // Search directions.
+    std::vector<ColorSpinorField> temp;              //! Sloppy temporary vector.
+    std::vector<std::vector<ColorSpinorField>> r;    // Current residual + intermediate residual values, along the MR.
+    std::vector<std::vector<ColorSpinorField>> u;    // Search directions.
 
-    ColorSpinorField x_sloppy; //! Sloppy solution vector.
-    ColorSpinorField r0;       //! Shadow residual, in BiCG language.
+    std::vector<ColorSpinorField> x_sloppy; //! Sloppy solution vector.
+    std::vector<ColorSpinorField> r0;       //! Shadow residual, in BiCG language.
 
     /**
       @brief Allocate persistent fields and parameter checking
-      @param[in] x Solution vector
-      @param[in] b Source vector
+      @param[in] x Solution vector set
+      @param[in] b Source vector set
    */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
     /**
       @brief Internal routine for reliable updates.  Made to not conflict with BiCGstab's implementation.
    */
     int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta);
 
     /**
      * @brief Internal routine for performing the MR part of BiCGstab-L
      *
-     * @param x_sloppy [out] sloppy accumulator for x
-     * @param fixed_iteration [in] whether or not this is for a fixed iteration solver
+     * @param[in,out] x_sloppy sloppy accumulator for x
+     * @param[in,out] u search directions
+     * @param[in,out] r residual vectors
+     * @param[in] fixed_iteration whether or not this is for a fixed iteration solver
+     * @param[in] src_idx which src we are presently working on
      */
-    void computeMR(ColorSpinorField &x_sloppy, bool fixed_iteration);
+    void computeMR(ColorSpinorField &x_sloppy, cvector_ref<ColorSpinorField> &u, cvector_ref<ColorSpinorField> &r,
+                   bool fixed_iteration, int src_idx);
 
    /**
      Legacy routines that encapsulate the original pipelined Gram-Schmidt.
    */
@@ -1100,29 +1041,36 @@ namespace quda {
     /**
      * @brief Internal routine that computes the "tau" matrix as described in
      * the original BiCGstab-L paper, supporting pipelining
      *
-     * @param begin [in] begin offset for pipelining
-     * @param size [in] length of pipelining
-     * @param j [in] row of tau being computed
+     * @param[in] begin begin offset for pipelining
+     * @param[in] size length of pipelining
+     * @param[in] j row of tau being computed
+     * @param[in] src_idx which src we are presently working on
      */
-    void computeTau(int begin, int size, int j);
+    void computeTau(int begin, int size, int j, cvector_ref<ColorSpinorField> &r, int src_idx);
 
     /**
      * @brief Internal routine that updates R as described in
      * the original BiCGstab-L paper, supporting pipelining.
      *
-     * @param begin [in] begin offset for pipelining
-     * @param size [in] length of pipelining
-     * @param j [in] row of tau being computed
+     * @param[in] begin begin offset for pipelining
+     * @param[in] size length of pipelining
+     * @param[in] j row of tau being computed
+     * @param[in,out] r Residual vector set
+     * @param[in] src_idx which src we are presently working on
      */
-    void updateR(int begin, int size, int j);
+    void updateR(int begin, int size, int j, cvector_ref<ColorSpinorField> &r, int src_idx);
 
     /**
      * @brief Internal legacy routine for performing the MR part of BiCGstab-L
      * which more closely matches the paper
      *
-     * @param x_sloppy [out] sloppy accumulator for x
+     * @param[in,out] x_sloppy sloppy accumulator for x
+     * @param[in,out] u Direction vector set
+     * @param[in,out] r Residual vector set
+     * @param[in] src_idx which src we are presently working on
      */
-    void legacyComputeMR(ColorSpinorField &x_sloppy);
+    void legacyComputeMR(ColorSpinorField &x_sloppy, cvector_ref<ColorSpinorField> &u, cvector_ref<ColorSpinorField> &r,
+                         int src_idx);
 
    /**
      Solver uses lazy allocation: this flag determines whether we have allocated or not.
    */
    bool init = false;
 
  public:
@@ -1135,7 +1083,7 @@ namespace quda {
     BiCGstabL(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matEig, SolverParam &param);
     virtual ~BiCGstabL();
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_BICGSTABL_INVERTER; }
   };
 
   class GCR : public Solver
   {
@@ -1155,35 +1103,36 @@ namespace quda {
    */
     int n_krylov;
 
-    std::vector<Complex> alpha;
-    std::vector<Complex> beta;
-    std::vector<double> gamma;
+    std::vector<std::vector<Complex>> alpha;
+    std::vector<std::vector<Complex>> beta;
+    std::vector<std::vector<double>> gamma;
 
     /**
       Solver uses lazy allocation: this flag to determine whether we have allocated.
     */
     bool init = false;
 
-    ColorSpinorField r;        //! residual vector
-    ColorSpinorField r_sloppy; //! sloppy residual vector
+    std::vector<ColorSpinorField> r;        //! residual vector
+    std::vector<ColorSpinorField> r_sloppy; //! sloppy residual vector
 
-    std::vector<ColorSpinorField> p;  // GCR direction vectors
-    std::vector<ColorSpinorField> Ap; // mat * direction vectors
+    int k_break = 0; //! track when the solver converged
+    std::vector<std::vector<ColorSpinorField>> p;  // GCR direction vectors
+    std::vector<std::vector<ColorSpinorField>> Ap; // mat * direction vectors
 
-    void computeBeta(std::vector<Complex> &beta, std::vector<ColorSpinorField> &Ap, int i, int N, int k);
-    void updateAp(std::vector<Complex> &beta, std::vector<ColorSpinorField> &Ap, int begin, int size, int k);
-    void orthoDir(std::vector<Complex> &beta, std::vector<ColorSpinorField> &Ap, int k, int pipeline);
+    void computeBeta(std::vector<Complex> &beta, cvector_ref<ColorSpinorField> &Ap, int i, int N, int k);
+    void updateAp(std::vector<Complex> &beta, cvector_ref<ColorSpinorField> &Ap, int begin, int size, int k);
+    void orthoDir(std::vector<Complex> &beta, cvector_ref<ColorSpinorField> &Ap, int k, int pipeline);
 
     void backSubs(const std::vector<Complex> &alpha, const std::vector<Complex> &beta, const std::vector<double> &gamma,
                   std::vector<Complex> &delta, int n);
 
     void updateSolution(ColorSpinorField &x, const std::vector<Complex> &alpha, const std::vector<Complex> &beta,
-                        std::vector<double> &gamma, int k, std::vector<ColorSpinorField> &p);
+                        std::vector<double> &gamma, int k, cvector_ref<ColorSpinorField> &p);
 
     /**
       @brief Initiate the fields needed by the solver
-      @param[in] x Solution vector
-      @param[in] b Source vector
+      @param[in] x Solution vector set
+      @param[in] b Source vector set
    */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     GCR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
        SolverParam &param);
@@ -1196,7 +1145,12 @@ namespace quda {
         const DiracMatrix &matEig, SolverParam &param);
     virtual ~GCR();
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
+
+    /**
+      @return Return the residual vector from the prior solve
+    */
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return false; } /** GCR is for any linear system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_GCR_INVERTER; }
   };
 
   class MR : public Solver
   {
   private:
@@ -1206,10 +1160,10 @@ namespace quda {
-    ColorSpinorField r;
-    ColorSpinorField r_sloppy;
-    ColorSpinorField Ar;
-    ColorSpinorField x_sloppy;
+    std::vector<ColorSpinorField> r;
+    std::vector<ColorSpinorField> r_sloppy;
+    std::vector<ColorSpinorField> Ar;
+    std::vector<ColorSpinorField> x_sloppy;
     bool init = false;
 
     /**
@@ -1217,17 +1171,17 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     MR(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam &param);
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return false; } /** MR is for any linear system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_MR_INVERTER; }
   };
 
   class CACG : public Solver
   {
   private:
@@ -1250,130 +1204,64 @@ namespace quda {
     bool lambda_init;
     QudaCABasis basis;
 
-    std::vector<double> Q_AQandg; // Fused inner product matrix
-    std::vector<double> Q_AS;     // inner product matrix
-    std::vector<double> alpha;    // QAQ^{-1} g
-    std::vector<double> beta;     // QAQ^{-1} QpolyS
+    std::vector<std::vector<double>> Q_AQandg; // Fused inner product matrix
+    std::vector<std::vector<double>> Q_AS;     // inner product matrix
+    std::vector<std::vector<double>> alpha;    // QAQ^{-1} g
+    std::vector<std::vector<double>> beta;     // QAQ^{-1} QpolyS
 
-    ColorSpinorField r;
+    std::vector<ColorSpinorField> r;
 
-    std::vector<ColorSpinorField> S;    // residual vectors
-    std::vector<ColorSpinorField> AS;   // mat * residual vectors. Can be replaced by a single temporary.
-    std::vector<ColorSpinorField> Q;    // CG direction vectors
-    std::vector<ColorSpinorField> Qtmp; // CG direction vectors for pointer swap
-    std::vector<ColorSpinorField> AQ;   // mat * CG direction vectors.
-                                        // it's possible to avoid carrying these
-                                        // around, but there's a stability penalty,
-                                        // and computing QAQ becomes a pain (though
-                                        // it does let you fuse the reductions...)
+    std::vector<std::vector<ColorSpinorField>> S;    // residual vectors
+    std::vector<std::vector<ColorSpinorField>> AS;   // mat * residual vectors. Can be replaced by a single temporary.
+    std::vector<std::vector<ColorSpinorField>> Q;    // CG direction vectors
+    std::vector<std::vector<ColorSpinorField>> Qtmp; // CG direction vectors for pointer swap
+    std::vector<std::vector<ColorSpinorField>> AQ;   // mat * CG direction vectors.
+                                                     // it's possible to avoid carrying these
+                                                     // around, but there's a stability penalty,
+                                                     // and computing QAQ becomes a pain (though
+                                                     // it does let you fuse the reductions...)
 
     /**
       @brief Initiate the fields needed by the solver
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
     /**
       @brief Compute the alpha coefficients
+      @param[in] b batch number
     */
-    void compute_alpha();
+    void compute_alpha(int b);
 
     /**
       @brief Compute the beta coefficients
+      @param[in] b batch number
     */
-    void compute_beta();
+    void compute_beta(int b);
 
     /**
-      @ brief Check if it's time for a reliable update
+      @brief Check if it's time for a reliable update
    */
     int reliable(double &rNorm, double &maxrr, int &rUpdate, const double &r2, const double &delta);
 
   public:
     CACG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
          SolverParam &param);
     virtual ~CACG();
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return true; } /** CG is only for Hermitian systems */
 
     virtual QudaInverterType getInverterType() const override { return QUDA_CA_CG_INVERTER; }
   };
 
-  class CACGNE : public CACG {
-
-  private:
-    DiracMMdag mmdag;
-    DiracMMdag mmdagSloppy;
-    DiracMMdag mmdagPrecon;
-    DiracMMdag mmdagEig;
-    ColorSpinorField xp;
-    ColorSpinorField yp;
-    bool init = false;
-
-    /**
-      @brief Initiate the fields needed by the solver
-      @param[in] x Solution vector
-      @param[in] b Source vector
-    */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
-
-  public:
-    CACGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
-           const DiracMatrix &matEig, SolverParam &param);
-
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
-
-    /**
-      @return Return the residual vector from the prior solve
-    */
-    ColorSpinorField &get_residual() override;
-
-    virtual bool hermitian() const final { return false; } /** CA-CGNE is for any linear system */
-
-    virtual QudaInverterType getInverterType() const final { return QUDA_CA_CGNE_INVERTER; }
-  };
-
-  class CACGNR : public CACG
-  {
-
-  private:
-    DiracMdagM mdagm;
-    DiracMdagM mdagmSloppy;
-    DiracMdagM mdagmPrecon;
-    DiracMdagM mdagmEig;
-    ColorSpinorField br;
-    bool init = false;
-
-    /**
-      @brief Initiate the fields needed by the solver
-      @param[in] x Solution vector
-      @param[in] b Source vector
-    */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
-
-  public:
-    CACGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
-           const DiracMatrix &matEig, SolverParam &param);
-
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
-
-    /**
-      @return Return the residual vector from the prior solve
-    */
-    ColorSpinorField &get_residual() override;
-
-    virtual bool hermitian() const final { return false; } /** CA-CGNR is for any linear system */
-
-    virtual QudaInverterType getInverterType() const final { return QUDA_CA_CGNR_INVERTER; }
-  };
-
   /**
      @brief Communication-avoiding GCR solver.  This solver does
     un-preconditioned GCR, first building up a polynomial in the
     ...
   */
   class CAGCR : public Solver
   {
   private:
@@ -1390,19 +1278,19 @@ namespace quda {
     bool lambda_init;  // whether or not lambda_max has been initialized
     QudaCABasis basis; // CA basis
 
-    std::vector<Complex> alpha; // Solution coefficient vectors
+    std::vector<std::vector<Complex>> alpha; // Solution coefficient vectors
 
-    ColorSpinorField r;
+    std::vector<ColorSpinorField> r;
 
-    std::vector<ColorSpinorField> p; // GCR direction vectors
-    std::vector<ColorSpinorField> q; // mat * direction vectors
+    std::vector<std::vector<ColorSpinorField>> p; // GCR direction vectors
+    std::vector<std::vector<ColorSpinorField>> q; // mat * direction vectors
 
     /**
       @brief Initiate the fields needed by the solver
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
     /**
       @brief Solve the equation A p_k psi_k = q_k psi_k = b by minimizing the
      residual
      @param[out] psi Array of coefficients
      @param[in] q Search direction vectors with the operator applied
      @param[in] b Source vector against which we are solving
    */
@@ -1411,19 +1299,19 @@ namespace quda {
-    void solve(std::vector<Complex> &psi, std::vector<ColorSpinorField> &q, ColorSpinorField &b);
+    void solve(std::vector<Complex> &psi, cvector_ref<ColorSpinorField> &q, ColorSpinorField &b);
 
   public:
     CAGCR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
           const DiracMatrix &matEig, SolverParam &param);
     virtual ~CAGCR();
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual vector from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return false; } /** GCR is for any linear system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_CA_GCR_INVERTER; }
   };
 
   // Steepest descent solver used as a preconditioner
   class SD : public Solver
   {
   private:
@@ -1433,8 +1321,8 @@ namespace quda {
-    ColorSpinorField Ar;
-    ColorSpinorField r;
+    std::vector<ColorSpinorField> Ar;
+    std::vector<ColorSpinorField> r;
     bool init = false;
 
     /**
@@ -1442,17 +1330,17 @@ namespace quda {
       @param[in] x Solution vector
       @param[in] b Source vector
     */
-    void create(ColorSpinorField &x, const ColorSpinorField &b);
+    void create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b);
 
   public:
     SD(const DiracMatrix &mat, SolverParam &param);
 
-    void operator()(ColorSpinorField &out, ColorSpinorField &in) override;
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override;
 
     /**
       @return Return the residual from the prior solve
     */
-    ColorSpinorField &get_residual() override;
+    cvector_ref<const ColorSpinorField> get_residual() override;
 
     virtual bool hermitian() const override { return false; } /** SD is for any linear system */
 
     virtual QudaInverterType getInverterType() const final { return QUDA_SD_INVERTER; }
   };
 
@@ -1477,14 +1365,15 @@ namespace quda {
     virtual ~PreconditionedSolver() { delete solver; }
 
-    void operator()(ColorSpinorField &x, ColorSpinorField &b) override
+    void operator()(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b) override
     {
+      if (x.size() != b.size()) errorQuda("Mismatched set sizes %lu != %lu", x.size(), b.size());
       pushOutputPrefix(prefix);
 
       QudaSolutionType solution_type = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? QUDA_MAT_SOLUTION : QUDA_MATPC_SOLUTION;
 
-      ColorSpinorField out;
-      ColorSpinorField in;
+      std::vector<ColorSpinorField> out(b.size());
+      std::vector<ColorSpinorField> in(b.size());
 
       if (dirac.hasSpecialMG()) {
         dirac.prepareSpecialMG(out, in, x, b, solution_type);
 
@@ -1680,11 +1569,16 @@ namespace quda {
     void RestartVT(const double beta, const double rho);
     void UpdateVm(ColorSpinorField &res, double beta, double sqrtr2);
     // EigCG solver:
-    int eigCGsolve(ColorSpinorField &out, ColorSpinorField &in);
+    int eigCGsolve(ColorSpinorField &out, const ColorSpinorField &in);
     // InitCG solver:
-    int initCGsolve(ColorSpinorField &out, ColorSpinorField &in);
+    int initCGsolve(ColorSpinorField &out, const ColorSpinorField &in);
     // Incremental eigCG solver (for eigcg and initcg calls)
-    void operator()(ColorSpinorField &out, ColorSpinorField &in);
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override
+    {
+      for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]);
+    }
+
+    void operator()(ColorSpinorField &out, const ColorSpinorField &in);
 
     virtual bool hermitian() const final { return true; } // EigCG is only for Hermitian systems
 
@@ -1721,7 +1615,12 @@ namespace quda {
     virtual ~GMResDR();
 
     //GMRES-DR solver
-    void operator()(ColorSpinorField &out, ColorSpinorField &in);
+    void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override
+    {
+      for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]);
+    }
+
+    void operator()(ColorSpinorField &out, const ColorSpinorField &in);
     //
     //GMRESDR method
     void RunDeflatedCycles (ColorSpinorField *out, ColorSpinorField *in, const double tol_threshold);
 
@@ -1753,4 +1652,18 @@ namespace quda {
   */
   bool is_ca_solver(QudaInverterType type);
 
+  /**
+     @brief Join the separate split-grid instances of
+     QudaInvertParam.  This function places the computed residuals
+     for each solve from the split grids in the expected order.
+     This function expects that we are using the default (global)
+     communicator.
+
+     @param[in, out] out The global joined instance of QudaInvertParam
+     @param[in] in The local split-grid instance of QudaInvertParam
+     @param[in] comm_key The CommKey that defines the split grid used
+     @param[in] split_rank The rank of the process within the split grid
+  */
+  void joinInvertParam(QudaInvertParam &out, const QudaInvertParam &in, const CommKey &comm_key, int split_rank);
+
 } // namespace quda
diff --git a/include/invert_x_update.h b/include/invert_x_update.h
index 6726275bde..64ec7caa33 100644
--- a/include/invert_x_update.h
+++ b/include/invert_x_update.h
@@ -43,6 +43,11 @@ namespace quda
       }
     }
 
+    XUpdateBatch(const XUpdateBatch &other) = default;
+    XUpdateBatch(XUpdateBatch &&other) = default;
+    XUpdateBatch &operator=(const XUpdateBatch &other) = default;
+    XUpdateBatch &operator=(XUpdateBatch &&other) = default;
+
     /**
       @brief use the vectors currently stored and add to the given output field
       @param x the output field to add to
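
Explicitly defaulting all four copy/move operations restores plain value semantics for XUpdateBatch, which matters now that these objects live inside resizable std::vector containers. A minimal illustration of the idiom (hypothetical class, not QUDA's):

    #include <vector>

    struct Batch {
      std::vector<int> fields;

      Batch() = default;
      Batch(const Batch &) = default;
      Batch(Batch &&) = default;
      Batch &operator=(const Batch &) = default;
      Batch &operator=(Batch &&) = default;
    };

Spelling the operations out as "= default" documents that the compiler-generated behavior is intended, and guards against a future user-declared member silently suppressing the implicit moves.
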
d : d + 4, parity, x_cb, 0, 0); + auto b = arg.halo.Ghost(d, 1, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; - if (forward_exterior[d]) { - if constexpr (doHalo()) { - int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { auto a = arg.Y(Arg::dagger ? d : d + 4, parity, x_cb, 0, 0); - auto b = arg.halo.Ghost(d, 1, their_spinor_parity, ghost_idx, 0, 0, 0); + auto b = arg.inA(their_spinor_parity, fwd_idx, 0, 0, 0); constexpr bool a_dagger = false; constexpr bool b_dagger = false; - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); - __syncthreads(); pipe.producer_acquire(); scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); pipe.producer_commit(); } - } else if constexpr (doBulk()) { + }; + + auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { + float rescale_factor; + if (forward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; + } + } else if constexpr (doBulk()) { - auto a = arg.Y(Arg::dagger ? 
d : d + 4, parity, x_cb, 0, 0); - auto b = arg.inA(their_spinor_parity, fwd_idx, 0, 0, 0); - constexpr bool a_dagger = false; - constexpr bool b_dagger = false; + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; - __syncthreads(); - pipe.producer_acquire(); - scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); - pipe.producer_commit(); - } - }; + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; - auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { - if (forward_exterior[d]) { - if constexpr (doHalo()) { - constexpr bool a_dagger = false; + pipe.consumer_wait(); + __syncthreads(); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; + } + return rescale_factor; + }; + + auto dslash_forward_compute = [&](int d, float rescale_factor) { + if (forward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); + } + }; + + auto dslash_backward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int back_idx = backward_idx[d]; + + if (backward_exterior[d]) { + if constexpr (doHalo()) { + const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace); + + auto a = arg.Y.Ghost(Arg::dagger ? d + 4 : d, 1 - parity, ghost_idx, 0, 0); + auto b = arg.halo.Ghost(d, 0, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { + const int gauge_idx = back_idx; + + auto a = arg.Y(Arg::dagger ? 
d + 4 : d, 1 - parity, gauge_idx, 0, 0); + auto b = arg.inA(their_spinor_parity, back_idx, 0, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + } + }; + + auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { + float rescale_factor; + if (backward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y.Ghost(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; + } + } else if constexpr (doBulk()) { + constexpr bool a_dagger = true; constexpr bool b_dagger = false; using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; pipe.consumer_wait(); __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, - smem_obj_b_imag); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; + } + return rescale_factor; + }; + + auto dslash_backward_compute = [&](int d, float rescale_factor) { + if (backward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); } - } else if constexpr (doBulk()) { + }; + + auto clover_producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int spinor_parity = (arg.nParity == 2) ? 
parity : 0; + + auto a = arg.X(0, parity, x_cb, 0, 0); + auto b = arg.inB(spinor_parity, x_cb, 0, 0); + constexpr bool a_dagger = Arg::dagger; + constexpr bool b_dagger = false; + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + }; - constexpr bool a_dagger = false; + auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) -> float { + constexpr bool a_dagger = Arg::dagger; constexpr bool b_dagger = false; - using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + using a_wrapper_t = decltype(arg.X(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inB(0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; pipe.consumer_wait(); __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); + return rescale_factor_a * rescale_factor_b; + }; + + auto clover_compute = [&](float rescale_factor) { + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); + }; + + float scale_inv_a; + float scale_inv_b; + + if constexpr (Arg::dslash) { + + dslash_forward_producer(0, scale_inv_a, scale_inv_b, 0); + + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + + // Forward gather - compute fwd offset for spinor fetch +#pragma unroll + for (int d = 0; d < Arg::nDim; d++) // loop over dimension + { + float rescale_factor = dslash_forward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_forward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else { + dslash_backward_producer(0, scale_inv_a, scale_inv_b, k_offset); + } + dslash_forward_compute(d, rescale_factor); + } // nDim + + // Backward gather - compute back offset for spinor and gauge fetch +#pragma unroll + for (int d = 0; d < Arg::nDim; d++) { + float rescale_factor = dslash_backward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_backward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else if (k_offset + Arg::bK < K) { + dslash_forward_producer(0, scale_inv_a, scale_inv_b, k_offset + Arg::bK); + } else if constexpr (doBulk() && Arg::clover) { + clover_producer(scale_inv_a, scale_inv_b, 0); + } + dslash_backward_compute(d, rescale_factor); + } // nDim + } + + accumulator.ax(-arg.kappa); } - }; - auto dslash_forward_compute = [&](int d) { - if (forward_exterior[d] && doHalo() || doBulk()) { - accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + /** + Applies the coarse clover matrix on a given parity and + checkerboard site index + */ + if constexpr (doBulk() && Arg::clover) { + if constexpr (!Arg::dslash) { clover_producer(scale_inv_a, scale_inv_b, 0); } + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + float rescale_factor = clover_consumer(scale_inv_a, scale_inv_b); + if (k_offset + Arg::bK < K) { clover_producer(scale_inv_a, scale_inv_b, k_offset + Arg::bK); } + clover_compute(rescale_factor); 
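+ // each pass of this loop overlaps staging of the next bK slice
+ // (clover_producer) with the conversion and MMA of the slice already
+ // resident in shared memory (clover_consumer / clover_compute)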
+ } } - }; - auto dslash_backward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { - const int back_idx = backward_idx[d]; + } else { - if (backward_exterior[d]) { - if constexpr (doHalo()) { - const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace); + auto dslash_forward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int fwd_idx = forward_idx[d]; - auto a = arg.Y.Ghost(Arg::dagger ? d + 4 : d, 1 - parity, ghost_idx, 0, 0); - auto b = arg.halo.Ghost(d, 0, their_spinor_parity, ghost_idx, 0, 0, 0); - constexpr bool a_dagger = true; - constexpr bool b_dagger = false; + if (forward_exterior[d]) { + if constexpr (doHalo()) { + int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace); + + auto a = arg.Y(Arg::dagger ? d : d + 4, parity, x_cb, 0, 0); + auto b = arg.halo.Ghost(d, 1, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { + + auto a = arg.Y(Arg::dagger ? d : d + 4, parity, x_cb, 0, 0); + auto b = arg.inA(their_spinor_parity, fwd_idx, 0, 0, 0); + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; __syncthreads(); pipe.producer_acquire(); scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); pipe.producer_commit(); } - } else if constexpr (doBulk()) { - const int gauge_idx = back_idx; + }; + + auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { + if (forward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, + smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + } + } else if constexpr (doBulk()) { - auto a = arg.Y(Arg::dagger ? 
d + 4 : d, 1 - parity, gauge_idx, 0, 0); - auto b = arg.inA(their_spinor_parity, back_idx, 0, 0, 0); - constexpr bool a_dagger = true; - constexpr bool b_dagger = false; + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; - __syncthreads(); - pipe.producer_acquire(); - scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); - pipe.producer_commit(); - } - }; + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; - auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { - if (backward_exterior[d]) { - if constexpr (doHalo()) { + pipe.consumer_wait(); + __syncthreads(); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + } + }; + + auto dslash_forward_compute = [&](int d) { + if (forward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + } + }; + + auto dslash_backward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int back_idx = backward_idx[d]; + + if (backward_exterior[d]) { + if constexpr (doHalo()) { + const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace); + + auto a = arg.Y.Ghost(Arg::dagger ? d + 4 : d, 1 - parity, ghost_idx, 0, 0); + auto b = arg.halo.Ghost(d, 0, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { + const int gauge_idx = back_idx; + + auto a = arg.Y(Arg::dagger ? 
d + 4 : d, 1 - parity, gauge_idx, 0, 0); + auto b = arg.inA(their_spinor_parity, back_idx, 0, 0, 0); constexpr bool a_dagger = true; constexpr bool b_dagger = false; - using a_wrapper_t = decltype(arg.Y.Ghost(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + } + }; + + auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { + if (backward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y.Ghost(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, + smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + } + } else if constexpr (doBulk()) { + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; pipe.consumer_wait(); __syncthreads(); a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, - smem_obj_b_imag); + b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); } - } else if constexpr (doBulk()) { - constexpr bool a_dagger = true; + }; + + auto dslash_backward_compute = [&](int d) { + if (backward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + } + }; + + auto clover_producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int spinor_parity = (arg.nParity == 2) ? 
parity : 0; + + auto a = arg.X(0, parity, x_cb, 0, 0); + auto b = arg.inB(spinor_parity, x_cb, 0, 0); + constexpr bool a_dagger = Arg::dagger; + constexpr bool b_dagger = false; + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + }; + + auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) { + constexpr bool a_dagger = Arg::dagger; constexpr bool b_dagger = false; - using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + using a_wrapper_t = decltype(arg.X(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inB(0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; @@ -317,102 +631,64 @@ namespace quda b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); - } - }; - - auto dslash_backward_compute = [&](int d) { - if (backward_exterior[d] && doHalo() || doBulk()) { - accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); - } - }; - - auto clover_producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { - const int spinor_parity = (arg.nParity == 2) ? parity : 0; - - auto a = arg.X(0, parity, x_cb, 0, 0); - auto b = arg.inB(spinor_parity, x_cb, 0, 0); - constexpr bool a_dagger = Arg::dagger; - constexpr bool b_dagger = false; - - __syncthreads(); - pipe.producer_acquire(); - scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); - pipe.producer_commit(); - }; - - auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) { - constexpr bool a_dagger = Arg::dagger; - constexpr bool b_dagger = false; - - using a_wrapper_t = decltype(arg.X(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.inB(0, 0, 0, 0)); - constexpr bool a_fixed = a_wrapper_t::fixed; - constexpr bool b_fixed = b_wrapper_t::fixed; - - pipe.consumer_wait(); - __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); - pipe.consumer_release(); - __syncthreads(); - }; + }; - auto clover_compute = [&]() { accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); }; + auto clover_compute + = [&]() { accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); }; - float scale_inv_a; - float scale_inv_b; + float scale_inv_a; + float scale_inv_b; - if constexpr (Arg::dslash) { + if constexpr (Arg::dslash) { - dslash_forward_producer(0, scale_inv_a, scale_inv_b, 0); + dslash_forward_producer(0, scale_inv_a, scale_inv_b, 0); - for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { - // Forward gather - compute fwd offset for spinor fetch + // Forward gather - compute fwd offset for spinor fetch #pragma unroll - for (int d = 0; d < Arg::nDim; d++) // loop over dimension - { - dslash_forward_consumer(d, scale_inv_a, scale_inv_b); - if (d < 3) { - dslash_forward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); - } else { - dslash_backward_producer(0, scale_inv_a, scale_inv_b, k_offset); - } - dslash_forward_compute(d); - } // nDim - - // Backward gather - compute back 
offset for spinor and gauge fetch + for (int d = 0; d < Arg::nDim; d++) // loop over dimension + { + dslash_forward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_forward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else { + dslash_backward_producer(0, scale_inv_a, scale_inv_b, k_offset); + } + dslash_forward_compute(d); + } // nDim + + // Backward gather - compute back offset for spinor and gauge fetch #pragma unroll - for (int d = 0; d < Arg::nDim; d++) { - dslash_backward_consumer(d, scale_inv_a, scale_inv_b); - if (d < 3) { - dslash_backward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); - } else if (k_offset + Arg::bK < K) { - dslash_forward_producer(0, scale_inv_a, scale_inv_b, k_offset + Arg::bK); - } else if constexpr (doBulk() && Arg::clover) { - clover_producer(scale_inv_a, scale_inv_b, 0); - } - dslash_backward_compute(d); - } // nDim - } + for (int d = 0; d < Arg::nDim; d++) { + dslash_backward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_backward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else if (k_offset + Arg::bK < K) { + dslash_forward_producer(0, scale_inv_a, scale_inv_b, k_offset + Arg::bK); + } else if constexpr (doBulk() && Arg::clover) { + clover_producer(scale_inv_a, scale_inv_b, 0); + } + dslash_backward_compute(d); + } // nDim + } - accumulator.ax(-arg.kappa); - } + accumulator.ax(-arg.kappa); + } - /** - Applies the coarse clover matrix on a given parity and - checkerboard site index - */ - if constexpr (doBulk() && Arg::clover) { - if constexpr (!Arg::dslash) { clover_producer(scale_inv_a, scale_inv_b, 0); } - for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { - clover_consumer(scale_inv_a, scale_inv_b); - if (k_offset + Arg::bK < K) { clover_producer(scale_inv_a, scale_inv_b, k_offset + Arg::bK); } - clover_compute(); + /** + Applies the coarse clover matrix on a given parity and + checkerboard site index + */ + if constexpr (doBulk() && Arg::clover) { + if constexpr (!Arg::dslash) { clover_producer(scale_inv_a, scale_inv_b, 0); } + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + clover_consumer(scale_inv_a, scale_inv_b); + if (k_offset + Arg::bK < K) { clover_producer(scale_inv_a, scale_inv_b, k_offset + Arg::bK); } + clover_compute(); + } } } - return accumulator; } diff --git a/include/monitor.h b/include/monitor.h index 24ac11ee95..1b06d2ffbd 100644 --- a/include/monitor.h +++ b/include/monitor.h @@ -1,3 +1,5 @@ +#include "device.h" + namespace quda { @@ -21,6 +23,25 @@ namespace quda */ void serialize(); + /** + @brief Get the current size of the monitor state. Used for + bookending a period for later analysis. + */ + size_t size(); + + struct state_t { + double energy = 0.0; + double power = 0.0; + double temp = 0.0; + double clock = 0.0; + }; + + /** + @brief Get the mean state observables between start and end, where + start and end are two intervals of history in the state. 
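+ @param[in] start Index into the state history at which the averaging window begins
+ @param[in] end Index into the state history at which the averaging window ends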
+ */ + state_t mean(size_t start, size_t end); + } // namespace monitor } // namespace quda diff --git a/include/multigrid.h b/include/multigrid.h index 35253bd0a6..798d1abf4e 100644 --- a/include/multigrid.h +++ b/include/multigrid.h @@ -99,6 +99,9 @@ namespace quda { /** Number of vectors used to define coarse space */ int Nvec; + /** Batch size when computing null space vectors */ + int n_vec_batch; + /** Number of times to apply Gram-Schmidt within a block */ int NblockOrtho; @@ -180,6 +183,7 @@ namespace quda { Nlevel(param.n_level), spinBlockSize(param.spin_block_size[level]), Nvec(param.n_vec[level]), + n_vec_batch(param.n_vec_batch[level]), NblockOrtho(param.n_block_ortho[level]), blockOrthoTwoPass(param.block_ortho_two_pass[level]), B(B), @@ -216,6 +220,7 @@ namespace quda { Nlevel(param.Nlevel), spinBlockSize(param.mg_global.spin_block_size[level]), Nvec(param.mg_global.n_vec[level]), + n_vec_batch(param.mg_global.n_vec_batch[level]), NblockOrtho(param.mg_global.n_block_ortho[level]), blockOrthoTwoPass(param.mg_global.block_ortho_two_pass[level]), coarse(param.coarse), @@ -294,14 +299,14 @@ namespace quda { /** The coarse-grid representation of the null space vectors */ std::vector B_coarse; - /** Residual vector */ - ColorSpinorField r; + /** Residual vector set */ + std::vector r; - /** Coarse residual vector */ - ColorSpinorField r_coarse; + /** Coarse residual vector set */ + std::vector r_coarse; - /** Coarse solution vector */ - ColorSpinorField x_coarse; + /** Coarse solution vector set */ + std::vector x_coarse; /** Kahler-Dirac Xinv */ std::shared_ptr xInvKD; @@ -440,7 +445,7 @@ namespace quda { @param out The solution vector @param in The residual vector (or equivalently the right hand side vector) */ - void operator()(ColorSpinorField &out, ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in); /** @brief Load the null space vectors in from file diff --git a/include/quda.h b/include/quda.h index ed974dbe9f..3481818cc4 100644 --- a/include/quda.h +++ b/include/quda.h @@ -145,8 +145,8 @@ extern "C" { double tol_hq; /**< Solver tolerance in the heavy quark residual norm */ int compute_true_res; /** Whether to compute the true residual post solve */ - double true_res; /**< Actual L2 residual norm achieved in solver */ - double true_res_hq; /**< Actual heavy quark residual norm achieved in solver */ + double true_res[QUDA_MAX_MULTI_SRC]; /**< Actual L2 residual norm achieved in the solver */ + double true_res_hq[QUDA_MAX_MULTI_SRC]; /**< Actual heavy quark residual norm achieved in the solver */ int maxiter; /**< Maximum number of iterations in the linear solver */ double reliable_delta; /**< Reliable update tolerance */ double reliable_delta_refinement; /**< Reliable update tolerance used in post multi-shift solver refinement */ @@ -278,6 +278,10 @@ extern "C" { int iter; /**< The number of iterations performed by the solver */ double gflops; /**< The Gflops rate of the solver */ double secs; /**< The time taken by the solver */ + double energy; /**< The energy consumed by the solver */ + double power; /**< The mean power of the solver */ + double temp; /**< The mean temperature of the device for the duration of the solve */ + double clock; /**< The mean clock frequency of the device for the duration of the solve */ QudaTune tune; /**< Enable auto-tuning? 
(default = QUDA_TUNE_YES) */ @@ -550,6 +554,8 @@ extern "C" { int batched_rotate; /** For block method solvers, the block size **/ int block_size; + /** The batch size used when computing eigenvalues **/ + int compute_evals_batch_size; /** For block method solvers, quit after n attempts at block orthonormalisation **/ int max_ortho_attempts; /** For hybrid modifeld Gram-Schmidt orthonormalisations **/ @@ -602,12 +608,6 @@ extern "C" { /** Whether to save eigenvectors in QIO singlefile or partfile format */ QudaBoolean partfile; - /** The Gflops rate of the eigensolver setup */ - double gflops; - - /**< The time taken by the eigensolver setup */ - double secs; - /** Which external library to use in the deflation operations (Eigen) */ QudaExtLibType extlib_type; //------------------------------------------------- @@ -655,6 +655,9 @@ extern "C" { /** Inverter to use in the setup phase */ QudaInverterType setup_inv_type[QUDA_MAX_MG_LEVEL]; + /** Solver batch size to use in the setup phase */ + int n_vec_batch[QUDA_MAX_MG_LEVEL]; + /** Number of setup iterations */ int num_setup_iter[QUDA_MAX_MG_LEVEL]; @@ -805,12 +808,6 @@ extern "C" { /** Whether to preserve the deflation space during MG update */ QudaBoolean preserve_deflation; - /** The Gflops rate of the multigrid solver setup */ - double gflops; - - /**< The time taken by the multigrid solver setup */ - double secs; - /** Multiplicative factor for the mu parameter */ double mu_factor[QUDA_MAX_MG_LEVEL]; @@ -1819,6 +1816,10 @@ extern "C" { double secs; /** Flops count for the smearing operations **/ double gflops; + double energy; /**< The energy consumed by the smearing operations */ + double power; /**< The mean power of the smearing operations */ + double temp; /**< The mean temperature of the device for the duration of the smearing operations */ + double clock; /**< The mean clock frequency of the device for the duration of the smearing operations */ } QudaQuarkSmearParam; diff --git a/include/quda_constants.h b/include/quda_constants.h index 983a49a0aa..983795cba7 100644 --- a/include/quda_constants.h +++ b/include/quda_constants.h @@ -32,15 +32,9 @@ /** * @def QUDA_MAX_BLOCK_SRC - * @brief Maximum number of sources that can be supported by the block solver + * @brief Maximum number of sources that can be supported by the multi-src solver */ -#define QUDA_MAX_BLOCK_SRC 64 - -/** - * @def QUDA_MAX_ARRAY - * @brief Maximum array length used in QudaInvertParam arrays - */ -#define QUDA_MAX_ARRAY_SIZE (QUDA_MAX_MULTI_SHIFT > QUDA_MAX_BLOCK_SRC ? 
QUDA_MAX_MULTI_SHIFT : QUDA_MAX_BLOCK_SRC) +#define QUDA_MAX_MULTI_SRC 128 /** * @def QUDA_MAX_DWF_LS diff --git a/include/reference_wrapper_helper.h b/include/reference_wrapper_helper.h index a9eccf3648..ed27d17106 100644 --- a/include/reference_wrapper_helper.h +++ b/include/reference_wrapper_helper.h @@ -246,11 +246,8 @@ namespace quda @param[in] first Begin iterator @param[in] last End iterator */ - template >* = nullptr> - vector_ref(U first, U last) + template > * = nullptr> vector_ref(U first, U last) : vector(first, last) { - vector::reserve(last - first); - for (auto it = first; it != last; it++) vector::push_back(*it); } /** @@ -490,6 +487,16 @@ namespace quda vector() = default; vector(uint64_t size, const T &value = {}) : std::vector(size, value) { } + /** + Constructor from pair of iterators + @param[in] first Begin iterator + @param[in] last End iterator + */ + template > * = nullptr> + vector(U first, U last) : std::vector(first, last) + { + } + /** @brief Constructor using std::vector initialization @param[in] u Vector we are copying from @@ -530,13 +537,43 @@ namespace quda /** @brief Cast to scalar. Only works if the vector size is 1. */ - operator T() const + explicit operator T() const { if (std::vector::size() != 1) errorQuda("Cast to scalar failed since size = %lu", std::vector::size()); return std::vector::operator[](0); } + + bool operator<(const vector &v) const + { + for (auto i = 0u; i < v.size(); i++) + if (this->operator[](i) >= v[i]) return false; + return true; + } + + bool operator>(const vector &v) const + { + for (auto i = 0u; i < v.size(); i++) + if (this->operator[](i) <= v[i]) return false; + return true; + } + + vector operator-() const + { + vector negative(*this); + for (auto &v : negative) v = -v; + return negative; + } + + vector operator*(const T &u) const + { + vector multiplied(*this); + for (auto &v : multiplied) v *= u; + return multiplied; + } }; + template vector operator*(const T &a, const vector &b) { return b * a; } + template using cvector = const vector; } // namespace quda diff --git a/include/reliable_updates.h b/include/reliable_updates.h index a1ddd9f122..ec70fb4f3c 100644 --- a/include/reliable_updates.h +++ b/include/reliable_updates.h @@ -163,8 +163,9 @@ namespace quda pnorm = pnorm + alpha * alpha * ppnorm; xnorm = sqrt(pnorm); d_new = d + params.u * rNorm + params.uhigh * params.Anorm * xnorm; - if (steps_since_reliable == 0 && getVerbosity() >= QUDA_DEBUG_VERBOSE) - printfQuda("New dnew: %e (r %e , y %e)\n", d_new, params.u * rNorm, params.uhigh * params.Anorm * xnorm); + if (steps_since_reliable == 0) + logQuda(QUDA_DEBUG_VERBOSE, "New dnew: %e (r %e , y %e)\n", d_new, params.u * rNorm, + params.uhigh * params.Anorm * xnorm); } steps_since_reliable++; } diff --git a/include/split_grid.h b/include/split_grid.h index 19a37c4311..2bd42bb476 100644 --- a/include/split_grid.h +++ b/include/split_grid.h @@ -37,7 +37,7 @@ namespace quda things and the extension to 4d is trivial. */ - auto processor_dim = comm_grid_dim / comm_key; // How many processors are there in a processor grid sub-parititon? + auto processor_dim = comm_grid_dim / comm_key; // How many processors are there in a processor grid sub-partition? auto partition_dim = comm_grid_dim / processor_dim; // How many such sub-partitions are there? 
partition_dim == comm_key diff --git a/include/targets/cuda/mma_tensor_op/gemm.cuh index 9f8eef8241..334675d495 100644 --- a/include/targets/cuda/mma_tensor_op/gemm.cuh +++ b/include/targets/cuda/mma_tensor_op/gemm.cuh @@ -6,6 +6,8 @@ #include #include +#include + namespace quda { namespace mma @@ -45,6 +47,8 @@ namespace quda reg_imag = 0; } + inline __device__ float abs_max(float a, float max) { return fmaxf(fabsf(a), max); } + /** @brief Load from global memory and store data in registers. */ @@ -80,6 +84,77 @@ namespace quda } } + /** + @brief Load from global memory and store data in registers while also applying a rescaling + */ + template + inline __device__ void convert_x_rescale(float &reg_real, float &reg_imag, complex *p, int m_idx, int n_idx, + float scale_inv, float rescale) + { + if (x) { + auto xx = p[m_idx * ld + n_idx]; + + if (fixed) { + reg_real = scale_inv * xx.real() * rescale; + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag = scale_inv_conj * xx.imag() * rescale; + } else { + reg_real = +xx.real() * rescale; + reg_imag = (dagger ? -xx.imag() : +xx.imag()) * rescale; + } + } else { + auto xx = p[n_idx * ld + m_idx]; + using store_type = T; + using store_array = typename VectorType::type; + store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); + + if (fixed) { + reg_real = scale_inv * xx.real() * rescale; + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag = scale_inv_conj * xx.imag() * rescale; + } else { + reg_real = xx.real() * rescale; + reg_imag = (dagger ? -xx.imag() : xx.imag()) * rescale; + } + } + } + + /** + @brief Find the absolute maximum of the data loaded from global memory. + */ + template + inline __device__ float find_abs_max(complex *p, int m_idx, int n_idx, float scale_inv) + { + float this_max = 0.0f; + + if (x) { + auto xx = p[m_idx * ld + n_idx]; + + if (fixed) { + this_max = abs_max(scale_inv * xx.real(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); + } else { + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(xx.imag(), this_max); + } + } else { + auto xx = p[n_idx * ld + m_idx]; + using store_type = T; + using store_array = typename VectorType::type; + store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); + + if (fixed) { + this_max = abs_max(scale_inv * xx.real(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); + } else { + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(xx.imag(), this_max); + } + } + + return this_max; + } + /** + @brief Load from global memory and store data in registers.
*/ @@ -196,6 +271,94 @@ namespace quda return gmem.get_scale_inv(); } + template + __device__ inline float tmp2s_rescale(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, + smem_accessor_t &smem_imag) + { + // for each iteration, each warp loads a tile + int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; + int warp_id = thread_id / 32; + int lane_id = thread_id % 32; + int thread_in_group = lane_id % 4; + int group_id = lane_id / 4; + constexpr int w_m = 8 * batch; + constexpr int w_k = 4; + static_assert(bM % w_m == 0, "bM %% w_m"); + static_assert(bN % w_k == 0, "bN %% w_k"); + + constexpr int tile_dim_m = bM / w_m; + constexpr int tile_dim_k = bN / w_k; + + constexpr int total_tiles = tile_dim_k * tile_dim_m; + constexpr int n_warp = block_y * block_z / 32; + constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; + + float thread_max = 0.0f; + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + constexpr bool x = (transpose == dagger); + float this_max = find_abs_max < x, fixed, dagger, + x ? bN + 4 : bM + 4 > (smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); + thread_max = fmaxf(this_max, thread_max); + } + } + + // block all-reduce thread_max + using block_reduce_t = cub::BlockReduce; + __shared__ typename block_reduce_t::TempStorage temp_storage; + float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float block_max_all; + if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { + if (block_max > 0.0f) { + block_max_all = block_max; + } else { + block_max_all = 1.0f; + } + } + __syncthreads(); + + float block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + load_t real; + load_t imag; + + constexpr bool x = (transpose == dagger); + convert_x_rescale(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, + scale_inv, block_rescale_factor); + smem_real.vector_load(smem_m_offset, smem_k_offset, real); + smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); + } + } + + return 1.0f / block_rescale_factor; + } + template __device__ inline void tmp2s(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, smem_accessor_t &smem_imag) @@ -374,6 +537,38 @@ namespace quda } } + /** @brief Apply the MMA, applying a rescaling factor before accumulating into the final accumulator */ + template + __device__ inline void mma_rescale(const SmemObjA &smem_obj_a_real, const SmemObjA &smem_obj_a_imag, + const SmemObjB &smem_obj_b_real, const SmemObjB &smem_obj_b_imag, float rescale) + { + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + typename mma_t::OperandC op_c_real_tmp; + op_c_real_tmp.zero(); + typename mma_t::OperandC
op_c_imag_tmp; + op_c_imag_tmp.zero(); + +#pragma unroll 1 + for (int tile_k = 0; tile_k < tile_acc_dim; tile_k++) { + + // The logical warp assigned to each part of the matrix. + const int logical_warp_index = wrm.warp_id * warp_cycle + c; + if (logical_warp_index < tile_row_dim * tile_col_dim) { + const int warp_row = logical_warp_index / tile_col_dim; + const int warp_col = logical_warp_index - warp_row * tile_col_dim; + + complex_mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, op_c_real_tmp, + op_c_imag_tmp, warp_row, warp_col, tile_k, wrm); + } + } + op_c_real[c].axpy(rescale, op_c_real_tmp); + op_c_imag[c].axpy(rescale, op_c_imag_tmp); + } + } + + /** @brief Apply MMA */ template __device__ inline void mma(const SmemObjA &smem_obj_a_real, const SmemObjA &smem_obj_a_imag, const SmemObjB &smem_obj_b_real, const SmemObjB &smem_obj_b_imag) diff --git a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh index b7851dc58f..39d99bc93a 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh @@ -22,6 +22,11 @@ namespace quda static __device__ __host__ constexpr int inline pad_size(int m) { return m == 48 ? 2 : 10; } + static constexpr bool do_rescale() + { + return true; // true because we use FP16 + } + static constexpr int MMA_M = 16; static constexpr int MMA_N = 16; static constexpr int MMA_K = 4; @@ -150,6 +155,12 @@ namespace quda for (int i = 0; i < 8; i++) { reg[i] *= alpha; } } + __device__ inline void axpy(float alpha, OperandC x) + { +#pragma unroll + for (int i = 0; i < 8; i++) { reg[i] += alpha * x.reg[i]; } + } + template __device__ inline void store(void *smem, int warp_row, int warp_col, const WarpRegisterMapping &wrm) { diff --git a/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh b/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh index b8e95a093c..2bdd83c74b 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh @@ -16,6 +16,11 @@ namespace quda static __device__ __host__ constexpr int inline pad_size(int m) { return m == 192 ? 
0 : 8; } + static constexpr bool do_rescale() + { + return true; // true because we use FP16 + } + static constexpr int MMA_M = 16; static constexpr int MMA_N = 8; static constexpr int MMA_K = 8; diff --git a/include/targets/cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh index 5fb40699af..42ddc42c4b 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh @@ -11,6 +11,11 @@ namespace quda template struct hmma_tfloat32_t { + static constexpr bool do_rescale() + { + return false; // false because TF32 has the same range as FP32 + } + static constexpr int warp_m = warp_m_; static constexpr int warp_n = warp_n_; diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh index 2d8fad7e90..c72bceae23 100644 --- a/include/targets/cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh +++ b/include/targets/cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh @@ -17,6 +17,11 @@ namespace quda static __device__ __host__ constexpr int inline pad_size(int) { return 0; } + static constexpr bool do_rescale() + { + return true; // true because we use FP16 + } + static constexpr int MMA_M = 16; static constexpr int MMA_N = 16; static constexpr int MMA_K = 4; diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh index e41e089fb2..2c801a4f46 100644 --- a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh @@ -91,6 +91,11 @@ namespace quda static constexpr bool use_intermediate_accumulator() { return true; }; + static constexpr bool do_rescale() + { + return std::is_same_v ? true : false; // true if we use FP16 + } + static constexpr int warp_m = warp_m_; static constexpr int warp_n = warp_n_; @@ -314,6 +319,12 @@ namespace quda for (int i = 0; i < warp_m * warp_n * thread_count; i++) { reg[i] *= alpha; } } + __device__ inline void axpy(float alpha, OperandC x) + { +#pragma unroll + for (int i = 0; i < warp_m * warp_n * thread_count; i++) { reg[i] += alpha * x.reg[i]; } + } + template __device__ void store(void *ptr, int warp_row, int warp_col, const WarpRegisterMapping &wrm) { // This method is only used for the mobius preconditioner where shuffle_t = half.
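The do_rescale() traits introduced above gate the rescale path in the coarse MMA kernels: when the operands are FP16, each staged tile is scaled so that its largest magnitude maps to 65504 (the largest finite FP16 value), and the inverse factor is folded back in through OperandC::axpy when the partial product is accumulated. A minimal host-side sketch of the arithmetic, using hypothetical names rather than the QUDA API:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
      // a "tile" whose entries would overflow FP16 (largest finite value 65504)
      std::vector<float> a = {1.0e6f, -3.0e5f, 7.5e4f};
      std::vector<float> b = {2.0f, -4.0f, 8.0f};

      // analogue of find_abs_max: largest magnitude in the tile
      float amax = 0.0f;
      for (auto v : a) amax = std::max(std::abs(v), amax);

      // analogue of convert_x_rescale: map the tile into FP16 range
      float rescale = 65504.0f / amax;

      // analogue of the MMA into a zero-initialized temporary accumulator (OperandC)
      float acc_tmp = 0.0f;
      for (std::size_t i = 0; i < a.size(); i++) acc_tmp += (a[i] * rescale) * b[i];

      // analogue of OperandC::axpy: undo the rescale while accumulating
      float acc = 1.0f / rescale * acc_tmp;

      printf("rescaled result = %g, direct result = %g\n", acc,
             1.0e6f * 2.0f + (-3.0e5f) * (-4.0f) + 7.5e4f * 8.0f);
      return 0;
    }

Because every entry of a tile is multiplied by the same factor, the transformation is exact up to rounding; only the dynamic range seen by the tensor cores changes, which is what makes FP16 usable despite its narrow exponent range.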
diff --git a/include/targets/cuda/reduce_helper.h b/include/targets/cuda/reduce_helper.h index 73fc0bfbc0..da3e178bcf 100644 --- a/include/targets/cuda/reduce_helper.h +++ b/include/targets/cuda/reduce_helper.h @@ -55,7 +55,8 @@ namespace quda bool reset = false; /** reset the counter post completion (required for multiple calls with the same arg instance */ using system_atomic_t = typename atomic_type::type; /** heterogeneous atomics must use lock-free atomics -> operate on scalars */ static constexpr int n_item = sizeof(T) / sizeof(system_atomic_t); /** number of words per reduction variable */ - cuda::atomic *partial; /** device atomic buffer */ + // FIXME on Hopper this could be a 128-bit type + cuda::atomic *partial; /** device atomic buffer */ cuda::atomic *result_d; /** device-mapped host atomic buffer */ cuda::atomic *result_h; /** host atomic buffer */ count_t *count; /** count array that is used to track the number of completed thread blocks at a given batch index */ @@ -78,7 +79,7 @@ namespace quda reset(reset), consumed(false) { - reducer::init(n_reduce, sizeof(*partial)); + reducer::init(n_reduce, n_item * sizeof(*partial)); // these buffers may be allocated in init, so we can't set the local copies until now partial = static_cast(reducer::get_device_buffer()); result_d = static_cast(reducer::get_mapped_buffer()); @@ -209,7 +210,15 @@ namespace quda if (arg.get_output_async_buffer()) { arg.get_output_async_buffer()[idx] = sum; } else { // write to device memory - arg.partial[idx].store(sum, cuda::std::memory_order_relaxed); + // write out the final reduced value + if (tid == 0) { + atomic_t sum_tmp[n]; + memcpy(sum_tmp, &sum, sizeof(sum)); +#pragma unroll + for (unsigned int i = 0; i < n; i++) { + arg.partial[n * idx + i].store(sum_tmp[i], cuda::std::memory_order_relaxed); + } + } } } } @@ -243,6 +252,8 @@ namespace quda template __device__ inline void reduce(Arg &arg, const Reducer &r, const T &in, const int idx) { + using atomic_t = typename atomic_type::type; + constexpr size_t n = sizeof(T) / sizeof(atomic_t); constexpr auto n_batch_block = std::min(Arg::max_n_batch_block, device::max_block_size()); using BlockReduce = BlockReduce; @@ -255,8 +266,14 @@ namespace quda if (target::thread_idx().x == 0 && target::thread_idx().y == 0) { // need to call placement new constructor since partial is not necessarily constructed - new (arg.partial + idx * target::grid_dim().x + target::block_idx().x) - cuda::atomic {aggregate}; + atomic_t aggregate_tmp[n]; + memcpy(aggregate_tmp, &aggregate, sizeof(aggregate)); + +#pragma unroll + for (int k = 0; k < n; k++) { + new (arg.partial + (idx * target::grid_dim().x + target::block_idx().x) * n + k) + cuda::atomic {aggregate_tmp[k]}; + } // increment global block counter for this reduction auto value = arg.count[idx].fetch_add(1, cuda::std::memory_order_release); @@ -272,7 +289,14 @@ namespace quda auto i = target::thread_idx().y * target::block_dim().x + target::thread_idx().x; T sum = r.init(); while (i < target::grid_dim().x) { - sum = r(sum, arg.partial[idx * target::grid_dim().x + i].load(cuda::std::memory_order_relaxed)); + atomic_t partial_tmp[n]; + T partial; +#pragma unroll + for (int k = 0; k < n; k++) { + partial_tmp[k] = arg.partial[(idx * target::grid_dim().x + i) * n + k].load(cuda::std::memory_order_relaxed); + } + memcpy(&partial, partial_tmp, sizeof(partial)); + sum = r(sum, partial); i += target::block_size<2>(); } diff --git a/include/timer.h b/include/timer.h index a98928ec39..9cbb58792a 100644 --- a/include/timer.h +++ 
b/include/timer.h @@ -241,14 +241,20 @@ namespace quda { the profile stack, and be popped when its destructor is called. */ struct pushProfile { - static inline double secs_dummy = 0; - static inline double gflops_dummy = 0; TimeProfile &profile; double &secs; double &gflops; + double &energy; + double &power; + double &temp; + double &clock; uint64_t flops; bool active = false; - pushProfile(TimeProfile &profile, double &secs = secs_dummy, double &gflops = gflops_dummy); + size_t monitor_start; + size_t monitor_end; + + pushProfile(TimeProfile &profile, QudaInvertParam *param = nullptr); + pushProfile(TimeProfile &profile, QudaQuarkSmearParam *param); virtual ~pushProfile(); }; diff --git a/include/tunable_nd.h b/include/tunable_nd.h index b942bc041d..f7076f9dff 100644 --- a/include/tunable_nd.h +++ b/include/tunable_nd.h @@ -169,6 +169,7 @@ namespace quda protected: mutable unsigned int vector_length_y; mutable unsigned int step_y; + mutable unsigned int step_y_bkup; bool tune_block_x; /** @@ -231,7 +232,11 @@ namespace quda */ TunableKernel2D_base(const LatticeField &field, unsigned int vector_length_y, QudaFieldLocation location = QUDA_INVALID_FIELD_LOCATION) : - TunableKernel1D_base(field, location), vector_length_y(vector_length_y), step_y(1), tune_block_x(true) + TunableKernel1D_base(field, location), + vector_length_y(vector_length_y), + step_y(1), + step_y_bkup(1), + tune_block_x(true) { } @@ -242,7 +247,11 @@ namespace quda @param[in] location Location where the calculation will take place */ TunableKernel2D_base(size_t n_items, unsigned int vector_length_y, QudaFieldLocation location) : - TunableKernel1D_base(n_items, location), vector_length_y(vector_length_y), step_y(1), tune_block_x(true) + TunableKernel1D_base(n_items, location), + vector_length_y(vector_length_y), + step_y(1), + step_y_bkup(1), + tune_block_x(true) { } @@ -317,7 +326,11 @@ namespace quda @brief Resize the autotuning step size in the y dimension @brief[in] y New step size */ - void resizeStep(int y) const { step_y = y; } + void resizeStep(int y) const + { + step_y = y; + step_y_bkup = step_y; + } }; /** @@ -410,6 +423,7 @@ namespace quda using TunableKernel2D_base::vector_length_y; mutable unsigned vector_length_z; mutable unsigned step_z; + mutable unsigned step_z_bkup; bool tune_block_y; /** @@ -478,6 +492,7 @@ namespace quda TunableKernel2D_base(field, vector_length_y, location), vector_length_z(vector_length_z), step_z(1), + step_z_bkup(1), tune_block_y(true) { } @@ -494,6 +509,7 @@ namespace quda TunableKernel2D_base(n_items, vector_length_y, location), vector_length_z(vector_length_z), step_z(1), + step_z_bkup(1), tune_block_y(true) { } @@ -581,6 +597,7 @@ namespace quda void resizeStep(int y, int z) const { step_z = z; + step_z_bkup = step_z; TunableKernel2D_base::resizeStep(y); } }; diff --git a/include/tunable_reduction.h b/include/tunable_reduction.h index c575cb8135..0ec7d4b6c2 100644 --- a/include/tunable_reduction.h +++ b/include/tunable_reduction.h @@ -144,7 +144,7 @@ namespace quda n_items(field.Volume()), block_size_y(block_size_y) { - if (commAsyncReduction()) strcat(aux, "async,"); + if (commAsyncReduction()) strcat(aux, ",async"); } /** @@ -155,7 +155,7 @@ namespace quda TunableReduction2D(size_t n_items, QudaFieldLocation location) : TunableKernel(n_items, location), n_items(n_items), block_size_y(1) { - if (commAsyncReduction()) strcat(aux, "async,"); + if (commAsyncReduction()) strcat(aux, ",async"); } /** diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 4d8639b9c9..950ac83a09 
100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -17,7 +17,7 @@ endif() set (QUDA_OBJS # cmake-format: sortable - monitor.cpp dirac_coarse.cpp dslash_coarse.cpp + solve.cpp monitor.cpp dirac_coarse.cpp dslash_coarse.cpp coarse_op.cpp coarsecoarse_op.cpp coarse_op_preconditioned.cpp staggered_coarse_op.cpp eig_iram.cpp eig_trlm.cpp eig_block_trlm.cpp vector_io.cpp @@ -29,6 +29,7 @@ set (QUDA_OBJS inv_multi_cg_quda.cpp inv_eigcg_quda.cpp gauge_ape.cu gauge_stout.cu gauge_hyp.cu gauge_wilson_flow.cu gauge_plaq.cu gauge_laplace.cpp gauge_observable.cpp + inv_cgnr.cpp inv_cgne.cpp inv_cg3_quda.cpp inv_ca_gcr.cpp inv_ca_cg.cpp inv_gcr_quda.cpp inv_mr_quda.cpp inv_sd_quda.cpp inv_pcg_quda.cpp inv_mre.cpp interface_quda.cpp util_quda.cpp diff --git a/lib/blas_quda.cu b/lib/blas_quda.cu index c76ae11881..1b5f552bca 100644 --- a/lib/blas_quda.cu +++ b/lib/blas_quda.cu @@ -130,7 +130,7 @@ namespace quda { return location == QUDA_CPU_FIELD_LOCATION ? false : Tunable::advanceTuneParam(param); } - long long flops() const override { return f.flops() * x.Length(); } + long long flops() const override { return f.flops() * x.Length() * x.size(); } long long bytes() const override { return (f.read.X + f.write.X) * x.Bytes() + (f.read.Y + f.write.Y) * y.Bytes() + diff --git a/lib/block_transpose.in.cu b/lib/block_transpose.in.cu index 952c6ad084..9d34fa0f40 100644 --- a/lib/block_transpose.in.cu +++ b/lib/block_transpose.in.cu @@ -37,10 +37,7 @@ namespace quda } else { strcat(aux, ",b2v"); } - strcat(aux, ",n_rhs="); - char rhs_str[8]; - i32toa(rhs_str, B.size()); - strcat(aux, rhs_str); + setRHSstring(aux, B.size()); resizeStep(1); apply(device::get_default_stream()); } @@ -116,7 +113,7 @@ namespace quda if constexpr (sizeof...(N) > 0) { launch_span_nVec(V, B, nVecs); } else { - errorQuda("nVec = %d not instantiated\n", V.Nvec()); + errorQuda("nVec = %d not instantiated", V.Nvec()); } } } @@ -138,7 +135,7 @@ namespace quda if constexpr (sizeof...(N) > 0) { launch_span_nColor(V, B, nVecs); } else { - errorQuda("nColor = %d not instantiated\n", V.Ncolor()); + errorQuda("nColor = %d not instantiated", V.Ncolor()); } } } @@ -184,7 +181,7 @@ namespace quda } else if (V.Precision() == QUDA_SINGLE_PRECISION && B[0].Precision() == QUDA_SINGLE_PRECISION) { if constexpr (is_enabled(QUDA_SINGLE_PRECISION)) block_transpose(V, B); } else { - errorQuda("Unsupported precision combination V=%d B=%d\n", V.Precision(), B[0].Precision()); + errorQuda("Unsupported precision combination V=%d B=%d", V.Precision(), B[0].Precision()); } } else { errorQuda("Multigrid has not been built"); diff --git a/lib/check_params.h b/lib/check_params.h index 5aba85a5f5..f227ee4716 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -196,6 +196,7 @@ void printQudaEigParam(QudaEigParam *param) { P(eig_type, QUDA_EIG_TR_LANCZOS); P(extlib_type, QUDA_EIGEN_EXTLIB); P(mem_type_ritz, QUDA_MEMORY_DEVICE); + P(compute_evals_batch_size, 4); P(ortho_block_size, 0); P(partfile, QUDA_BOOLEAN_FALSE); #else @@ -226,6 +227,7 @@ void printQudaEigParam(QudaEigParam *param) { P(eig_type, QUDA_EIG_INVALID); P(extlib_type, QUDA_EXTLIB_INVALID); P(mem_type_ritz, QUDA_MEMORY_INVALID); + P(compute_evals_batch_size, INVALID_INT); P(ortho_block_size, INVALID_INT); P(partfile, QUDA_BOOLEAN_INVALID); #endif @@ -450,8 +452,11 @@ void printQudaInvertParam(QudaInvertParam *param) { #ifndef CHECK_PARAM P(pipeline, 0); /** Whether to use a pipelined solver */ P(num_offset, 0); /**< Number of offsets in the multi-shift solver */ - P(num_src, 1); /**< Number 
of offsets in the multi-shift solver */ + P(num_src, 1); /**< Number of sources to solve for simultaneously */ P(overlap, 0); /**< width of domain overlaps */ +#else + if (param->num_src > QUDA_MAX_MULTI_SRC) + errorQuda("num_src %d exceeds limit of %d", param->num_src, QUDA_MAX_MULTI_SRC); #endif #ifdef INIT_PARAM @@ -460,6 +465,17 @@ void printQudaInvertParam(QudaInvertParam *param) { #else for (int d = 0; d < 4; d++) { P(split_grid[d], INVALID_INT); } /**< Grid of sub-partitions */ P(num_src_per_sub_partition, INVALID_INT); /**< Number of sources per sub-partitions */ +#ifdef CHECK_PARAM + int split_grid_size = 1; + for (int d = 0; d < 4; d++) split_grid_size *= param->split_grid[d]; + if (split_grid_size > 1) { + if (param->num_src_per_sub_partition < 1) + errorQuda("Invalid num_src_per_subpartition = %d", param->num_src_per_sub_partition); + if (param->num_src % param->num_src_per_sub_partition != 0) + errorQuda("num_src %d not compatible with num_src_per_sub_partition %d", param->num_src, + param->num_src_per_sub_partition); + } +#endif #endif #ifdef INIT_PARAM @@ -659,10 +675,18 @@ void printQudaInvertParam(QudaInvertParam *param) { P(iter, 0); P(gflops, 0.0); P(secs, 0.0); + P(energy, 0.0); + P(power, 0.0); + P(temp, 0.0); + P(clock, 0.0); #elif defined(PRINT_PARAM) P(iter, INVALID_INT); P(gflops, INVALID_DOUBLE); P(secs, INVALID_DOUBLE); + P(energy, INVALID_DOUBLE); + P(power, INVALID_DOUBLE); + P(temp, INVALID_DOUBLE); + P(clock, INVALID_DOUBLE); #endif @@ -952,12 +976,14 @@ void printQudaMultigridParam(QudaMultigridParam *param) { #endif #ifdef INIT_PARAM - if (i + std::enable_if_t + launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const qudaStream_t &stream, Tunable &tunable) + { + if (query_max) return 2; + switch (tp.aux.x) { + // clang-format off + case 0: launch_compute_uv_kernel< 64, 64, 32, 8, 16>(tp, arg, min_threads, stream, tunable); break; + case 1: launch_compute_uv_kernel< 64, 64, 32, 8, 32>(tp, arg, min_threads, stream, tunable); break; + case 2: launch_compute_uv_kernel< 64, 64, 32, 16, 16>(tp, arg, min_threads, stream, tunable); break; + // clang-format on + default: + errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin, + Arg::fineColor, Arg::coarseColor); + } + return -1; + } + // note --- currently unused, may be revisited in the future template std::enable_if_t @@ -326,6 +344,7 @@ namespace quda || (Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) + || (Arg::fineColor == 32 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2)), @@ -554,6 +573,23 @@ namespace quda return -1; } + template + std::enable_if_t + launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const qudaStream_t &stream, Tunable &tunable) + { + if (query_max) return 3; + // clang-format off + switch (tp.aux.x) { + case 0: launch_compute_vuv_kernel< 64, 64, 32, 8, 8>(tp, arg, min_threads, stream, tunable); break; + case 1: launch_compute_vuv_kernel< 64, 
64, 32, 8, 16>(tp, arg, min_threads, stream, tunable); break; + case 2: launch_compute_vuv_kernel< 64, 64, 32, 16, 8>(tp, arg, min_threads, stream, tunable); break; + case 3: launch_compute_vuv_kernel< 64, 64, 32, 32, 4>(tp, arg, min_threads, stream, tunable); break; + default: errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor); } + // clang-format on + return -1; + } + // note --- currently unused, may be revisited in the future template std::enable_if_t @@ -631,6 +667,7 @@ namespace quda || (Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) + || (Arg::fineColor == 32 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2)), diff --git a/lib/color_spinor_util.in.cu b/lib/color_spinor_util.in.cu index 0b9355d4d1..ef370eda18 100644 --- a/lib/color_spinor_util.in.cu +++ b/lib/color_spinor_util.in.cu @@ -423,4 +423,18 @@ namespace quda { resize(v, new_size, param); } + void create_alias(cvector_ref &alias, cvector_ref &v, + const ColorSpinorParam &param) + { + if (alias.size() != v.size()) errorQuda("sets differ in size (%lu != %lu)", alias.size(), v.size()); + for (auto i = 0u; i < v.size(); i++) alias[i] = const_cast(v[i]).create_alias(param); + } + + void create_alias(std::vector &alias, cvector_ref &v, + const ColorSpinorParam &param) + { + alias.resize(v.size()); + create_alias(cvector_ref(alias), v, param); + } + } // namespace quda diff --git a/lib/communicator_stack.cpp b/lib/communicator_stack.cpp index 386a4a910d..ef5db59a0a 100644 --- a/lib/communicator_stack.cpp +++ b/lib/communicator_stack.cpp @@ -29,6 +29,7 @@ namespace quda auto search = communicator_stack.find(default_comm_key); if (search == communicator_stack.end()) { fprintf(getOutputFile(), "Default communicator can't be found\n"); + fflush(getOutputFile()); comm_abort(1); } return search->second; @@ -39,6 +40,7 @@ namespace quda auto search = communicator_stack.find(current_key); if (search == communicator_stack.end()) { fprintf(getOutputFile(), "Current communicator can't be found\n"); + fflush(getOutputFile()); comm_abort(1); } return search->second; diff --git a/lib/copy_color_spinor_mg.in.hpp b/lib/copy_color_spinor_mg.in.hpp index 92b96b4c1a..343e0a18ff 100644 --- a/lib/copy_color_spinor_mg.in.hpp +++ b/lib/copy_color_spinor_mg.in.hpp @@ -17,7 +17,8 @@ namespace quda { template - class CopySpinor : TunableKernel1D { + class CopySpinor : TunableKernel3D + { ColorSpinorField &out; const ColorSpinorField &in; FloatOut *Out; @@ -27,12 +28,8 @@ namespace quda { unsigned int minThreads() const { return in.VolumeCB(); } public: - CopySpinor(ColorSpinorField &out, const ColorSpinorField &in, QudaFieldLocation location, FloatOut* Out, FloatIn* In) : - TunableKernel1D(in, location), - out(out), - in(in), - Out(Out), - In(In) + CopySpinor(ColorSpinorField &out, const ColorSpinorField &in, QudaFieldLocation location, FloatOut *Out, FloatIn *In) : + TunableKernel3D(in, in.Nspin(), in.Ncolor(), location),
out(out), in(in), Out(Out), In(In) { apply(device::get_default_stream()); } @@ -44,7 +41,6 @@ namespace quda { launch(tp, stream, CopyArg(out, in, Out, In)); } - long long flops() const { return 0; } long long bytes() const { return in.Bytes() + out.Bytes(); } }; diff --git a/lib/dirac_clover.cpp b/lib/dirac_clover.cpp index a7f8702e79..8fdf896552 100644 --- a/lib/dirac_clover.cpp +++ b/lib/dirac_clover.cpp @@ -82,10 +82,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracClover::reconstruct(cvector_ref &, cvector_ref &, @@ -225,30 +223,25 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(b[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - CloverInv(src[i], b[i][other_parity], other_parity); - DiracWilson::DslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - CloverInv(src[i], tmp, this_parity); - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - src[i] = x[i][other_parity].create_alias(); - CloverInv(tmp, b[i][other_parity], other_parity); // safe even when tmp = b.odd - DiracWilson::DslashXpay(src[i], tmp, this_parity, b[this_parity], kappa); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) + CloverInv(src, b(other_parity), other_parity); + DiracWilson::DslashXpay(tmp, src, this_parity, b(this_parity), kappa); + CloverInv(src, tmp, this_parity); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + CloverInv(tmp, b(other_parity), other_parity); // safe even when tmp = b.odd + DiracWilson::DslashXpay(src, tmp, this_parity, b(this_parity), kappa); } } @@ -257,13 +250,11 @@ namespace quda { { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; - auto tmp = getFieldTmp(b[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = A_oo^-1 (b_o + k D_oe x_e) - DiracWilson::DslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - CloverInv(x[i][other_parity], tmp, other_parity); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // x_o = A_oo^-1 (b_o + k D_oe x_e) + DiracWilson::DslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + CloverInv(x(other_parity), tmp, other_parity); } void DiracCloverPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double, double mu, diff --git a/lib/dirac_clover_hasenbusch_twist.cpp b/lib/dirac_clover_hasenbusch_twist.cpp index fe08ac765a..1e7c614526 100644 --- a/lib/dirac_clover_hasenbusch_twist.cpp +++ b/lib/dirac_clover_hasenbusch_twist.cpp @@ -29,14 +29,14 @@ namespace quda void DiracCloverHasenbuschTwist::M(cvector_ref &out, cvector_ref &in) const { if (symmetric) { - 
ApplyWilsonCloverHasenbuschTwist(out[this_parity], in[other_parity], *gauge, *clover, -kappa, mu, in[this_parity], + ApplyWilsonCloverHasenbuschTwist(out(this_parity), in(other_parity), *gauge, *clover, -kappa, mu, in(this_parity), this_parity, dagger, commDim.data, profile); - ApplyWilsonClover(out[other_parity], in[this_parity], *gauge, *clover, -kappa, in[other_parity], other_parity, + ApplyWilsonClover(out(other_parity), in(this_parity), *gauge, *clover, -kappa, in(other_parity), other_parity, dagger, commDim.data, profile); } else { - ApplyWilsonClover(out[other_parity], in[this_parity], *gauge, *clover, -kappa, in[other_parity], other_parity, + ApplyWilsonClover(out(other_parity), in(this_parity), *gauge, *clover, -kappa, in(other_parity), other_parity, dagger, commDim.data, profile); - ApplyTwistedClover(out[this_parity], in[other_parity], *gauge, *clover, -kappa, mu, in[this_parity], this_parity, + ApplyTwistedClover(out(this_parity), in(other_parity), *gauge, *clover, -kappa, mu, in(this_parity), this_parity, dagger, commDim.data, profile); } } @@ -115,10 +115,6 @@ namespace quda double kappa2 = -kappa * kappa; auto tmp = getFieldTmp(out); - bool symmetric = (matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false; - int odd_bit = (matpcType == QUDA_MATPC_ODD_ODD || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0; - QudaParity parity[2] = {static_cast((1 + odd_bit) % 2), static_cast((0 + odd_bit) % 2)}; - if (!symmetric) { // No need to change order of calls for dagger // because the asymmetric operator is actually symmetric @@ -126,27 +122,27 @@ namespace quda // the pieces in Dslash and DslashXPay respect the dagger // DiracCloverHasenbuschTwistPC::Dslash applies A^{-1}Dslash - Dslash(tmp, in, parity[0]); + Dslash(tmp, in, other_parity); // applies (A + imu*g5 - kappa^2 D)- - ApplyTwistedClover(out, tmp, *gauge, *clover, kappa2, mu, in, parity[1], dagger, commDim.data, profile); + ApplyTwistedClover(out, tmp, *gauge, *clover, kappa2, mu, in, this_parity, dagger, commDim.data, profile); } else if (!dagger) { // symmetric preconditioning // We need two cases because M = 1-ADAD and M^\dag = 1-D^\dag A D^dag A // where A is actually a clover inverse. 
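Written out in the shorthand of the comment above (A the clover inverse, D the hopping term, the kappa^2 factor folded in as in the code), the two branches are simply a product and its dagger: M = 1 - ADAD, while M^\dagger = 1 - (ADAD)^\dagger = 1 - D^\dagger A^\dagger D^\dagger A^\dagger = 1 - DADA, since the clover inverse is Hermitian and D^\dagger is obtained by running the dslash with the dagger flag set. Hence the dagger branch below must apply the clover inverse first and finish with a dslash, in mirrored order to the non-dagger branch.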
// This is the non-dag case: AD - Dslash(tmp, in, parity[0]); + Dslash(tmp, in, other_parity); // Then x + AD (AD) - DslashXpayTwistClovInv(out, tmp, parity[1], in, kappa2, mu); + DslashXpayTwistClovInv(out, tmp, this_parity, in, kappa2, mu); } else { // symmetric preconditioning, dagger // This is the dagger: 1 - DADA // i) Apply A - CloverInv(out, in, parity[1]); + CloverInv(out, in, this_parity); // ii) Apply A D => ADA - Dslash(tmp, out, parity[0]); + Dslash(tmp, out, other_parity); // iii) Apply x + D(ADA) - DslashXpayTwistNoClovInv(out, tmp, parity[1], in, kappa2, mu); + DslashXpayTwistNoClovInv(out, tmp, this_parity, in, kappa2, mu); } } diff --git a/lib/dirac_coarse.cpp b/lib/dirac_coarse.cpp index c138b95158..62f4f7639e 100644 --- a/lib/dirac_coarse.cpp +++ b/lib/dirac_coarse.cpp @@ -121,10 +121,6 @@ namespace quda { { } - DiracCoarse::~DiracCoarse() - { - } - void DiracCoarse::createY(bool gpu, bool mapped) const { int ndim = transfer->Vectors().Ndim(); @@ -455,10 +451,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracCoarse::reconstruct(cvector_ref &, cvector_ref &, @@ -513,19 +507,17 @@ namespace quda { /* do nothing */ } - DiracCoarsePC::~DiracCoarsePC() { } - void DiracCoarsePC::Dslash(cvector_ref &out, cvector_ref &in, QudaParity parity) const { QudaFieldLocation location = checkLocation(out[0], in[0]); initializeLazy(location); - if ( location == QUDA_CUDA_FIELD_LOCATION) { + if (location == QUDA_CUDA_FIELD_LOCATION) { auto Y = apply_mma(out, dslash_use_mma) ? Yhat_aos_d : Yhat_d; auto X = apply_mma(out, dslash_use_mma) ? 
X_aos_d : X_d; ApplyCoarse(out, in, in, *Y, *X, kappa, parity, true, false, dagger, commDim.data, halo_precision, dslash_use_mma); - } else if ( location == QUDA_CPU_FIELD_LOCATION ) { + } else if (location == QUDA_CPU_FIELD_LOCATION) { ApplyCoarse(out, in, in, *Yhat_h, *X_h, kappa, parity, true, false, dagger, commDim.data, halo_precision, dslash_use_mma); } @@ -536,7 +528,7 @@ namespace quda { { // FIXME emulated for now Dslash(out, in, parity); - for (auto i = 0u; i < x.size(); i++) blas::xpay(x[i], k, out[i]); + blas::xpay(x, k, out); } void DiracCoarsePC::M(cvector_ref &out, cvector_ref &in) const @@ -553,14 +545,14 @@ namespace quda { // DiracCoarse::DslashXpay applies (A - D) // FIXME this ignores the -1 DiracCoarse::Dslash(out, tmp, QUDA_EVEN_PARITY); Clover(tmp, in, QUDA_EVEN_PARITY); - for (auto i = 0u; i < in.size(); i++) blas::xpay(tmp[i], -1.0, out[i]); + blas::xpay(tmp, -1.0, out); } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { // DiracCoarsePC::Dslash applies A^{-1}Dslash Dslash(tmp, in, QUDA_EVEN_PARITY); // DiracCoarse::DslashXpay applies (A - D) // FIXME this ignores the -1 DiracCoarse::Dslash(out, tmp, QUDA_ODD_PARITY); Clover(tmp, in, QUDA_ODD_PARITY); - for (auto i = 0u; i < in.size(); i++) blas::xpay(tmp[i], -1.0, out[i]); + blas::xpay(tmp, -1.0, out); } else if (matpcType == QUDA_MATPC_EVEN_EVEN) { Dslash(tmp, in, QUDA_ODD_PARITY); DslashXpay(out, tmp, QUDA_EVEN_PARITY, in, -1.0); @@ -585,68 +577,56 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } - auto tmp = getFieldTmp(b[0].Even()); + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + auto tmp = getFieldTmp(x.Even()); // we desire solution to full system - for (auto i = 0u; i < b.size(); i++) { - - if (symmetric) { - // src = A_ee^-1 (b_e - D_eo A_oo^-1 b_o) - src[i] = x[i][other_parity].create_alias(); + if (symmetric) { + // src = A_ee^-1 (b_e - D_eo A_oo^-1 b_o) #if 0 - CloverInv(src[i], b[other_parity], other_parity); - DiracCoarse::Dslash(tmp, src[i], this_parity); - blas::xpay(b[i][this_parity], -1.0, tmp); - CloverInv(src[i], tmp, this_parity); + CloverInv(src, b(other_parity), other_parity); + DiracCoarse::Dslash(tmp, src, this_parity); + blas::xpay(b(this_parity), -1.0, tmp); + CloverInv(src, tmp, this_parity); +#else + // src = A_ee^{-1} b_e - (A_ee^{-1} D_eo) A_oo^{-1} b_o + CloverInv(src, b(other_parity), other_parity); + Dslash(tmp, src, this_parity); + CloverInv(src, b(this_parity), this_parity); + blas::axpy(-1.0, tmp, src); #endif - // src = A_ee^{-1} b_e - (A_ee^{-1} D_eo) A_oo^{-1} b_o - CloverInv(src[i], b[i][other_parity], other_parity); - Dslash(tmp, src[i], this_parity); - CloverInv(src[i], b[i][this_parity], this_parity); - blas::axpy(-1.0, tmp, src[i]); - - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e - D_eo A_oo^-1 b_o - src[i] = x[i][other_parity].create_alias(); - CloverInv(tmp, b[i][other_parity], other_parity); - DiracCoarse::Dslash(src[i], tmp, this_parity); - blas::xpay(b[i][this_parity], -1.0, src[i]); - sol[i] = x[i][this_parity].create_alias(); - } + } else { + // src = b_e - D_eo A_oo^-1 b_o + CloverInv(tmp, b(other_parity), other_parity); + DiracCoarse::Dslash(src, tmp, this_parity); + blas::xpay(b(this_parity), -1.0, src); } } void 
DiracCoarsePC::reconstruct(cvector_ref &x, cvector_ref &b, const QudaSolutionType solType) const { - if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - return; - } - - auto tmp = getFieldTmp(b[0].Even()); - - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); + if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; + checkFullSpinor(x, b); + auto tmp = getFieldTmp(x.Even()); #if 0 - // x_o = A_oo^-1 (b_o - D_oe x_e) - DiracCoarse::Dslash(tmp, x.Even(), QUDA_ODD_PARITY); - blas::xpay(b.Odd(), -1.0, tmp); - CloverInv(x.Odd(), tmp, QUDA_ODD_PARITY); + // x_o = A_oo^-1 (b_o - D_oe x_e) + DiracCoarse::Dslash(tmp, x(this_parity), other_parity); + blas::xpay(b(other_parity), -1.0, tmp); + CloverInv(x(other_parity), tmp, other_parity); +#else + // x_o = A_oo^{-1} b_o - (A_oo^{-1} D_oe) x_e + Dslash(tmp, x(this_parity), other_parity); + CloverInv(x(other_parity), b(other_parity), other_parity); + blas::axpy(-1.0, tmp, x(other_parity)); #endif - // x_o = A_oo^{-1} b_o - (A_oo^{-1} D_oe) x_e - Dslash(tmp, x[i][this_parity], other_parity); - CloverInv(x[i][other_parity], b[i][other_parity], other_parity); - blas::axpy(-1.0, tmp, x[i][other_parity]); - } } //Make the coarse operator one level down. For the preconditioned diff --git a/lib/dirac_domain_wall.cpp b/lib/dirac_domain_wall.cpp index a7ad1581b4..6e5886116d 100644 --- a/lib/dirac_domain_wall.cpp +++ b/lib/dirac_domain_wall.cpp @@ -84,10 +84,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracDomainWall::reconstruct(cvector_ref &, cvector_ref &, @@ -152,20 +150,17 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - for (auto i = 0u; i < b.size(); i++) { - // src = b_e + k D_eo b_o - DslashXpay(x[i][other_parity], b[i][other_parity], this_parity, b[this_parity], kappa5); - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); - } + // src = b_e + k D_eo b_o + DslashXpay(x(other_parity), b(other_parity), this_parity, b(this_parity), kappa5); } void DiracDomainWallPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -174,11 +169,9 @@ namespace quda { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = b_o + k D_oe x_e - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], kappa5); - } + checkFullSpinor(x, b); + // x_o = b_o + k D_oe x_e + DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), kappa5); } } // namespace quda diff --git a/lib/dirac_domain_wall_4d.cpp b/lib/dirac_domain_wall_4d.cpp index bfa5e980e4..fe5b4e943e 100644 --- a/lib/dirac_domain_wall_4d.cpp +++ b/lib/dirac_domain_wall_4d.cpp @@ -86,10 +86,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); 
i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracDomainWall4D::reconstruct(cvector_ref &, cvector_ref &, @@ -171,30 +169,25 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = M5^-1 (b_e + k D4_eo*M5^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - M5inv(src[i], b[i][other_parity]); - Dslash4Xpay(tmp, src[i], this_parity, b[i][this_parity], kappa5); - M5inv(src[i], tmp); - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e + k D4_eo*M5^-1 b_o - src[i] = x[i][other_parity].create_alias(); - M5inv(tmp, b[i][other_parity]); - Dslash4Xpay(src[i], tmp, this_parity, b[i][this_parity], kappa5); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = M5^-1 (b_e + k D4_eo*M5^-1 b_o) + M5inv(src, b(other_parity)); + Dslash4Xpay(tmp, src, this_parity, b(this_parity), kappa5); + M5inv(src, tmp); + } else { + // src = b_e + k D4_eo*M5^-1 b_o + M5inv(tmp, b(other_parity)); + Dslash4Xpay(src, tmp, this_parity, b(this_parity), kappa5); } } @@ -204,13 +197,11 @@ namespace quda { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = M5^-1 (b_o + k D4_oe x_e) - Dslash4Xpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa5); - M5inv(x[i][other_parity], tmp); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // x_o = M5^-1 (b_o + k D4_oe x_e) + Dslash4Xpay(tmp, x(this_parity), other_parity, b(other_parity), kappa5); + M5inv(x(other_parity), tmp); } } // end namespace quda diff --git a/lib/dirac_improved_staggered.cpp b/lib/dirac_improved_staggered.cpp index b5e2b0d326..82f6537fe2 100644 --- a/lib/dirac_improved_staggered.cpp +++ b/lib/dirac_improved_staggered.cpp @@ -93,10 +93,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracImprovedStaggered::reconstruct(cvector_ref &, cvector_ref &, @@ -212,18 +210,17 @@ namespace quda { return; } - for (auto i = 0u; i < b.size(); i++) { - // we desire solution to full system. - // With the convention given in DiracStaggered::M(), - // the source is src = 2m b_e + D_eo b_o - // But remember, DslashXpay actually applies - // -D_eo. Flip the sign on 2m to compensate, and - // then flip the overall sign. - src[i] = x[i][other_parity].create_alias(); - DslashXpay(src[i], b[i][other_parity], this_parity, b[i][this_parity], -2.0 * mass); - blas::ax(-1.0, src[i]); - sol[i] = x[i][this_parity].create_alias(); - } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + + // we desire solution to full system. 
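Each of these prepare/reconstruct pairs is an instance of the same even-odd (Schur complement) reduction; in block form,

\left(\begin{array}{cc} A_{ee} & D_{eo} \\ D_{oe} & A_{oo} \end{array}\right)
\left(\begin{array}{c} x_e \\ x_o \end{array}\right)
= \left(\begin{array}{c} b_e \\ b_o \end{array}\right)
\quad\Longrightarrow\quad
\left(A_{ee} - D_{eo} A_{oo}^{-1} D_{oe}\right) x_e = b_e - D_{eo} A_{oo}^{-1} b_o,
\qquad
x_o = A_{oo}^{-1}\left(b_o - D_{oe}\, x_e\right).

prepare() builds the reduced right-hand side (aliased onto the unused parity of x, as in the create_alias calls above) and reconstruct() recovers the other parity afterwards; for Wilson-type operators the off-diagonal block carries a factor of -kappa, which is why the minus signs above appear as +kappa (or kappa5) in the DslashXpay calls.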
+ // With the convention given in DiracStaggered::M(), + // the source is src = 2m b_e + D_eo b_o + // But remember, DslashXpay actually applies + // -D_eo. Flip the sign on 2m to compensate, and + // then flip the overall sign. + DslashXpay(src, b(other_parity), this_parity, b(this_parity), -2.0 * mass); + blas::ax(-1.0, src); } void DiracImprovedStaggeredPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -233,19 +230,17 @@ namespace quda { return; } - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - - // create full solution - // With the convention given in DiracStaggered::M(), - // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) - // But remember: DslashXpay actually applies -D_oe, - // so just like above we need to flip the sign - // on b_o. We then correct this by applying an additional - // minus sign when we rescale by 2m. - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], -1.0); - blas::ax(-0.5 / mass, x[i][other_parity]); - } + checkFullSpinor(x, b); + + // create full solution + // With the convention given in DiracStaggered::M(), + // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) + // But remember: DslashXpay actually applies -D_oe, + // so just like above we need to flip the sign + // on b_o. We then correct this by applying an additional + // minus sign when we rescale by 2m. + DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), -1.0); + blas::ax(-0.5 / mass, x(other_parity)); } void DiracImprovedStaggeredPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double, double mass, diff --git a/lib/dirac_mobius.cpp b/lib/dirac_mobius.cpp index 7e59dded58..c5567f435c 100644 --- a/lib/dirac_mobius.cpp +++ b/lib/dirac_mobius.cpp @@ -136,10 +136,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracMobius::reconstruct(cvector_ref &, cvector_ref &, @@ -364,49 +362,42 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - M5inv(tmp, b[i][other_parity]); - Dslash4pre(src[i], tmp); - Dslash4Xpay(tmp, src[i], this_parity, b[i][this_parity], 1.0); - M5inv(src[i], tmp); - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e + k D4_eo * D4pre * D5inv b_o - src[i] = x[i][other_parity].create_alias(); - M5inv(src[i], b[i][other_parity]); - Dslash4pre(tmp, src[i]); - Dslash4Xpay(src[i], tmp, this_parity, b[i][this_parity], 1.0); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) + M5inv(tmp, b(other_parity)); + Dslash4pre(src, tmp); + Dslash4Xpay(tmp, src, this_parity, b(this_parity), 1.0); + M5inv(src, tmp); + } else { + // src = b_e + k D4_eo * D4pre * D5inv b_o + M5inv(src, 
b(other_parity)); + Dslash4pre(tmp, src); + Dslash4Xpay(src, tmp, this_parity, b(this_parity), 1.0); } } void DiracMobiusPC::reconstruct(cvector_ref &x, cvector_ref &b, const QudaSolutionType solType) const { - if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { return; } + if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) - Dslash4pre(x[i][other_parity], x[i][this_parity]); - Dslash4Xpay(tmp, x[i][other_parity], other_parity, b[i][other_parity], 1.0); - M5inv(x[i][other_parity], tmp); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) + Dslash4pre(x(other_parity), x(this_parity)); + Dslash4Xpay(tmp, x(other_parity), other_parity, b(other_parity), 1.0); + M5inv(x(other_parity), tmp); } void DiracMobiusPC::MdagMLocal(cvector_ref &out, cvector_ref &in) const @@ -592,10 +583,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracMobiusEofa::reconstruct(cvector_ref &, cvector_ref &, @@ -676,32 +665,27 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - m5inv_eofa(tmp, b[i][other_parity]); - Dslash4pre(src[i], tmp); - Dslash4Xpay(tmp, src[i], this_parity, b[i][this_parity], 1.0); - m5inv_eofa(src[i], tmp); - sol[i] = x[i][this_parity].create_alias(); - } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { - // src = b_e + k D4_eo * D4pre * D5inv b_o - src[i] = x[i][other_parity].create_alias(); - m5inv_eofa(src[i], b[i][other_parity]); - Dslash4pre(tmp, src[i]); - Dslash4Xpay(src[i], tmp, this_parity, b[i][this_parity], 1.0); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) + m5inv_eofa(tmp, b(other_parity)); + Dslash4pre(src, tmp); + Dslash4Xpay(tmp, src, this_parity, b(this_parity), 1.0); + m5inv_eofa(src, tmp); + } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { + // src = b_e + k D4_eo * D4pre * D5inv b_o + m5inv_eofa(src, b(other_parity)); + Dslash4pre(tmp, src); + Dslash4Xpay(src, tmp, this_parity, b(this_parity), 1.0); } } @@ -711,14 +695,12 @@ namespace quda { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) - Dslash4pre(x[i][other_parity], x[i][this_parity]); - Dslash4Xpay(tmp, x[i][other_parity], other_parity, b[i][other_parity], 1.0); - 
m5inv_eofa(x[i][other_parity], tmp); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) + Dslash4pre(x(other_parity), x(this_parity)); + Dslash4Xpay(tmp, x(other_parity), other_parity, b(other_parity), 1.0); + m5inv_eofa(x(other_parity), tmp); } void DiracMobiusEofaPC::MdagM(cvector_ref &out, cvector_ref &in) const diff --git a/lib/dirac_staggered.cpp b/lib/dirac_staggered.cpp index c850733ce6..7a034bf3b8 100644 --- a/lib/dirac_staggered.cpp +++ b/lib/dirac_staggered.cpp @@ -88,10 +88,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracStaggered::reconstruct(cvector_ref &, cvector_ref &, @@ -208,26 +206,23 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - // we desire solution to preconditioned system - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + // we desire solution to preconditioned system + create_alias(src, b); + create_alias(sol, x); return; } - for (auto i = 0u; i < b.size(); i++) { - // we desire solution to full system. - // With the convention given in DiracStaggered::M(), - // the source is src = 2m b_e + D_eo b_o - // But remember, DslashXpay actually applies - // -D_eo. Flip the sign on 2m to compensate, and - // then flip the overall sign. - src[i] = x[i][other_parity].create_alias(); - DslashXpay(src[i], b[i][other_parity], this_parity, b[i][this_parity], -2.0 * mass); - blas::ax(-1.0, src[i]); - sol[i] = x[i][this_parity].create_alias(); - } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + + // we desire solution to full system. + // With the convention given in DiracStaggered::M(), + // the source is src = 2m b_e + D_eo b_o + // But remember, DslashXpay actually applies + // -D_eo. Flip the sign on 2m to compensate, and + // then flip the overall sign. + DslashXpay(src, b(other_parity), this_parity, b(this_parity), -2.0 * mass); + blas::ax(-1.0, src); } void DiracStaggeredPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -237,19 +232,17 @@ namespace quda { return; } - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - - // create full solution - // With the convention given in DiracStaggered::M(), - // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) - // But remember: DslashXpay actually applies -D_oe, - // so just like above we need to flip the sign - // on b_o. We then correct this by applying an additional - // minus sign when we rescale by 2m. - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], -1.0); - blas::ax(-0.5 / mass, x[i][other_parity]); - } + checkFullSpinor(x, b); + + // create full solution + // With the convention given in DiracStaggered::M(), + // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) + // But remember: DslashXpay actually applies -D_oe, + // so just like above we need to flip the sign + // on b_o. We then correct this by applying an additional + // minus sign when we rescale by 2m. 
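The double sign flip described in this comment is easy to sanity-check with scalars. A minimal sketch, where the hypothetical dslash_xpay(in, x, a) = -d*in + a*x stands in for the staggered DslashXpay and the scalar d for D_oe:

#include <cstdio>

int main()
{
  double m = 0.1, d = 0.7; // hypothetical mass and hopping "operator"
  double x_e = 2.0, b_o = 3.0;

  // stand-in for DslashXpay, which applies -D rather than +D
  auto dslash_xpay = [d](double in, double x, double a) { return -d * in + a * x; };

  double t = dslash_xpay(x_e, b_o, -1.0); // t = -(d*x_e + b_o)
  double x_o = -0.5 / m * t;              // rescale by 2m with the compensating minus
  printf("x_o = %g vs (b_o + d*x_e)/(2m) = %g\n", x_o, (b_o + d * x_e) / (2.0 * m));
}

Both prints agree (up to rounding), confirming that the two extra minus signs cancel.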
+ DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), -1.0); + blas::ax(-0.5 / mass, x(other_parity)); } void DiracStaggeredPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double, double mass, double, diff --git a/lib/dirac_twisted_clover.cpp b/lib/dirac_twisted_clover.cpp index cb6b46bdb1..ca630da80c 100644 --- a/lib/dirac_twisted_clover.cpp +++ b/lib/dirac_twisted_clover.cpp @@ -125,10 +125,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracTwistedClover::reconstruct(cvector_ref &, cvector_ref &, @@ -295,31 +293,27 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); - - TwistCloverInv(!symmetric ? static_cast(tmp) : src[i], b[i][other_parity], other_parity); + auto tmp = getFieldTmp(x.Even()); + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); - if (symmetric) { - // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) - WilsonDslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - WilsonDslashXpay(src[i], tmp, this_parity, b[i][this_parity], kappa); - } + TwistCloverInv(!symmetric ? 
tmp : src, b(other_parity), other_parity); - if (symmetric) TwistCloverInv(src[i], tmp, this_parity); + if (symmetric) { + // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) + WilsonDslashXpay(tmp, src, this_parity, b(this_parity), kappa); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + WilsonDslashXpay(src, tmp, this_parity, b(this_parity), kappa); } + + if (symmetric) TwistCloverInv(src, tmp, this_parity); } void DiracTwistedCloverPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -327,13 +321,11 @@ namespace quda { { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = A_oo^-1 (b_o + k D_oe x_e) - WilsonDslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - TwistCloverInv(x[i][other_parity], tmp, other_parity); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // x_o = A_oo^-1 (b_o + k D_oe x_e) + WilsonDslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + TwistCloverInv(x(other_parity), tmp, other_parity); } void DiracTwistedCloverPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double, diff --git a/lib/dirac_twisted_mass.cpp b/lib/dirac_twisted_mass.cpp index a6b13706fb..c7909145ce 100644 --- a/lib/dirac_twisted_mass.cpp +++ b/lib/dirac_twisted_mass.cpp @@ -108,10 +108,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracTwistedMass::reconstruct(cvector_ref &, cvector_ref &, @@ -165,14 +163,14 @@ namespace quda { if (in.TwistFlavor() != out.TwistFlavor()) errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor()); if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == QUDA_TWIST_INVALID) - errorQuda("Twist flavor not set %d\n", in.TwistFlavor()); + errorQuda("Twist flavor not set %d", in.TwistFlavor()); if (in.TwistFlavor() == QUDA_TWIST_SINGLET) { double a = -2.0 * kappa * mu; // for inverse twist double b = 1.0 / (1.0 + a * a); bool asymmetric - = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyTwistedMassPreconditioned(out, in, *gauge, b, a, false, in, parity, dagger, asymmetric, commDim.data, profile); } else {//TWIST doublet : double a = 2.0 * kappa * mu; @@ -180,7 +178,7 @@ namespace quda { double c = 1.0 / (1.0 + a * a - b * b); bool asymmetric - = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyNdegTwistedMassPreconditioned(out, in, *gauge, c, -2.0 * mu * kappa, 2.0 * kappa * epsilon, false, in, parity, dagger, asymmetric, commDim.data, profile); } @@ -200,9 +198,9 @@ namespace quda { if(in.TwistFlavor() == QUDA_TWIST_SINGLET) { double a = -2.0 * kappa * mu; // for inverse twist double b = k / (1.0 + a * a); - // asymmetric should never be true here since we never need to apply 1 + k * A^{-1} D^\dagger + // asymmetric should never be false here since we never need to apply 1 + k * A^{-1} D^\dagger bool asymmetric - = (matpcType == 
QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyTwistedMassPreconditioned(out, in, *gauge, b, a, true, x, parity, dagger, asymmetric, commDim.data, profile); } else {//TWIST_DOUBLET: double a = 2.0 * kappa * mu; @@ -210,7 +208,7 @@ namespace quda { double c = 1.0 / (1.0 + a * a - b * b); bool asymmetric - = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyNdegTwistedMassPreconditioned(out, in, *gauge, k * c, -2 * mu * kappa, 2 * kappa * epsilon, true, x, parity, dagger, asymmetric, commDim.data, profile); } @@ -244,56 +242,50 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - bool symmetric = (matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false; + auto tmp = getFieldTmp(x.Even()); + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); - for (auto i = 0u; i < b.size(); i++) { - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); + TwistInv(symmetric ? src : tmp, b(other_parity)); - TwistInv(symmetric ? src[i] : static_cast(tmp), b[i][other_parity]); + if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { - if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { - - if (symmetric) { - // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) - DiracWilson::DslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - DiracWilson::DslashXpay(src[i], tmp, this_parity, b[i][this_parity], kappa); - } + if (symmetric) { + // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) + DiracWilson::DslashXpay(tmp, src, this_parity, b(this_parity), kappa); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + DiracWilson::DslashXpay(src, tmp, this_parity, b(this_parity), kappa); + } - } else { // doublet: + } else { // doublet: - // repurpose the preconditioned dslash as a vectorized operator: 1+kappa*D - double mu_ = mu; - mu = 0.0; - double epsilon_ = epsilon; - epsilon = 0.0; + // repurpose the preconditioned dslash as a vectorized operator: 1+kappa*D + double mu_ = mu; + mu = 0.0; + double epsilon_ = epsilon; + epsilon = 0.0; - if (symmetric) { - // src = A_ee^-1(b_e + k D_eo A_oo^-1 b_o) - DslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - DslashXpay(src[i], tmp, this_parity, b[i][this_parity], kappa); - } + if (symmetric) { + // src = A_ee^-1(b_e + k D_eo A_oo^-1 b_o) + DslashXpay(tmp, src, this_parity, b(this_parity), kappa); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + DslashXpay(src, tmp, this_parity, b(this_parity), kappa); + } - mu = mu_; - epsilon = epsilon_; + mu = mu_; + epsilon = epsilon_; - } // end of doublet + } // end of doublet - if (symmetric) TwistInv(src[i], tmp); - } + if (symmetric) TwistInv(src, tmp); } void DiracTwistedMassPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -301,29 +293,27 @@ namespace quda { { if (solType == QUDA_MATPC_SOLUTION || solType == 
QUDA_MATPCDAG_MATPC_SOLUTION) return; - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - - // create full solution - if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { - // x_o = A_oo^-1 (b_o + k D_oe x_e) - DiracWilson::DslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - } else { // twist doublet: - double mu_ = mu; - mu = 0.0; - double epsilon_ = epsilon; - epsilon = 0.0; - - // x_o = A_oo^-1 (b_o + k D_oe x_e) - DslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - - mu = mu_; - epsilon = epsilon_; - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + + // create full solution + if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { + // x_o = A_oo^-1 (b_o + k D_oe x_e) + DiracWilson::DslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + } else { // twist doublet: + double mu_ = mu; + mu = 0.0; + double epsilon_ = epsilon; + epsilon = 0.0; - TwistInv(x[i][other_parity], tmp); + // x_o = A_oo^-1 (b_o + k D_oe x_e) + DslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + + mu = mu_; + epsilon = epsilon_; } + + TwistInv(x(other_parity), tmp); } void DiracTwistedMassPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double, diff --git a/lib/dirac_wilson.cpp b/lib/dirac_wilson.cpp index 30565cfba3..c9570e57ee 100644 --- a/lib/dirac_wilson.cpp +++ b/lib/dirac_wilson.cpp @@ -12,8 +12,6 @@ namespace quda { // hack (for DW and TM operators) DiracWilson::DiracWilson(const DiracParam &param, const int) : Dirac(param) { } - DiracWilson::~DiracWilson() { } - DiracWilson& DiracWilson::operator=(const DiracWilson &dirac) { if (&dirac != this) { Dirac::operator=(dirac); } @@ -75,10 +73,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracWilson::reconstruct(cvector_ref &, cvector_ref &, @@ -105,11 +101,6 @@ namespace quda { DiracWilsonPC::DiracWilsonPC(const DiracWilsonPC &dirac) : DiracWilson(dirac) { } - DiracWilsonPC::~DiracWilsonPC() - { - - } - DiracWilsonPC& DiracWilsonPC::operator=(const DiracWilsonPC &dirac) { if (&dirac != this) { @@ -145,22 +136,19 @@ namespace quda { cvector_ref &x, cvector_ref &b, const QudaSolutionType solType) const { - if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + // we desire solution to preconditioned system + create_alias(src, b); + create_alias(sol, x); return; } // we desire solution to full system - for (auto i = 0u; i < b.size(); i++) { - // src = b_e + k D_eo b_o - DslashXpay(x[i][other_parity], b[i][other_parity], this_parity, b[i][this_parity], kappa); - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); - } + // src = b_e + k D_eo b_o + DslashXpay(x(other_parity), b(other_parity), this_parity, b(this_parity), kappa); + + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); } void DiracWilsonPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -169,11 +157,9 @@ namespace quda { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]);
- // x_o = b_o + k D_oe x_e - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], kappa); - } + checkFullSpinor(x, b); + // x_o = b_o + k D_oe x_e + DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), kappa); } } // namespace quda diff --git a/lib/dslash_clover_helper.cu b/lib/dslash_clover_helper.cu index 0754191850..2fc342bfb4 100644 --- a/lib/dslash_clover_helper.cu +++ b/lib/dslash_clover_helper.cu @@ -38,7 +38,7 @@ namespace quda { long long flops() const { return in.size() * in.Volume() * 504ll; } - long long bytes() const { return in.size() * (out.Bytes() + in.Bytes() + clover.Bytes() / (3 - in.SiteSubset())); } + long long bytes() const { return out.Bytes() + in.Bytes() + clover.Bytes() / (3 - in.SiteSubset()); } }; //Apply the clover matrix field to a colorspinor field @@ -123,10 +123,10 @@ namespace quda { long long flops() const { return in.size() * (inverse ? 1056ll : 552ll) * in.Volume(); } long long bytes() const { - long long rtn = out.Bytes() + in.Bytes() + clover.Bytes() / (3 - in.SiteSubset()); + long long rtn = out.Bytes() + in.Bytes() + in.size() * clover.Bytes() / (3 - in.SiteSubset()); if (twist == QUDA_TWIST_GAMMA5_INVERSE && !clover::dynamic_inverse()) - rtn += clover.Bytes() / (3 - in.SiteSubset()); - return in.size() * rtn; + rtn += in.size() * clover.Bytes() / (3 - in.SiteSubset()); + return rtn; } }; diff --git a/lib/dslash_pack2.cu b/lib/dslash_pack2.cu index 1db387df98..043ea65abc 100644 --- a/lib/dslash_pack2.cu +++ b/lib/dslash_pack2.cu @@ -53,8 +53,6 @@ namespace quda template class Pack : TunableKernel3D { - -protected: void **ghost; const ColorSpinorField &halo; cvector_ref &in; @@ -166,8 +164,11 @@ protected: case Device: strcat(aux, ",device-device"); break; case Host: strcat(aux, comm_peer2peer_enabled_global() ? ",host-device" : ",host-host"); break; case Shmem: strcat(aux, ",shmem"); break; - default: errorQuda("Unknown pack target location %d\n", location); + default: errorQuda("Unknown pack target location %d", location); } +#ifdef STRIPED + strcat(aux, ",striped"); +#endif } public: @@ -340,7 +341,7 @@ public: #endif } else { - errorQuda("Unsupported nSpin = %d\n", in.Nspin()); + errorQuda("Unsupported nSpin = %d", in.Nspin()); } } diff --git a/lib/dslash_policy.hpp b/lib/dslash_policy.hpp index dce21f1be3..f7d024085f 100644 --- a/lib/dslash_policy.hpp +++ b/lib/dslash_policy.hpp @@ -1859,8 +1859,12 @@ namespace quda } if (comm_nvshmem_enabled()) { - enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKINTRA_DSLASH); - enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKFULL_DSLASH); + if (in.size() <= 16) { + // FIXME until uber dslash gets fine-grained + // synchronization, we cannot use it with large RHS + enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKINTRA_DSLASH); + enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKFULL_DSLASH); + } enable_policy(QudaDslashPolicy::QUDA_SHMEM_PACKINTRA_DSLASH); enable_policy(QudaDslashPolicy::QUDA_SHMEM_PACKFULL_DSLASH); } diff --git a/lib/eigensolve_quda.cpp b/lib/eigensolve_quda.cpp index 47e2b1b60e..8130897cbb 100644 --- a/lib/eigensolve_quda.cpp +++ b/lib/eigensolve_quda.cpp @@ -142,11 +142,11 @@ namespace quda { // Use 0th vector to extract meta data for the RNG. RNG rng(kSpace[0], 1234); + // If the spinor contains valid initial data from the user preserve it, else populate with rands.
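A concrete illustration of why the validity test below uses !isfinite || norm == 0 rather than std::isnormal: subnormal norms are tiny but numerically legal and must not trigger re-randomization, while NaN/inf and exactly-zero vectors must. A standalone check:

#include <cmath>
#include <cstdio>

int main()
{
  double subnormal = 1e-310; // below DBL_MIN (~2.2e-308): subnormal, but a valid value
  double nan_val = std::nan("");

  printf("isnormal(1e-310)   = %d\n", std::isnormal(subnormal));                      // 0: would wrongly reject
  printf("replace subnormal? = %d\n", !std::isfinite(subnormal) || subnormal == 0.0); // 0: kept
  printf("replace nan?       = %d\n", !std::isfinite(nan_val) || nan_val == 0.0);     // 1: re-randomized
}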
+ // We use `!isfinite || norm == 0` instead of `isnormal` because subnormal vectors are still numerically legal + auto norm = blas::norm2({kSpace.begin(), kSpace.begin() + block_size}); for (int b = 0; b < block_size; b++) { - // If the spinor contains valid initial data from the user preserve it, else populate with rands. - // We use `!isfinite || norm == 0` instead of `isnormal` because subnormal vectors are still numerically legal - auto norm = blas::norm2(kSpace[b]); - if (!std::isfinite(norm) || norm == 0.0) { spinorNoise(kSpace[b], rng, QUDA_NOISE_UNIFORM); } + if (!std::isfinite(norm[b]) || norm[b] == 0.0) { spinorNoise(kSpace[b], rng, QUDA_NOISE_UNIFORM); } } bool orthed = false; @@ -496,13 +496,14 @@ namespace quda { logQuda(QUDA_SUMMARIZE, "Computing SVD of M\n"); + auto batch_size = eig_param->compute_evals_batch_size; int n_conv = eig_param->n_conv; if (evecs.size() < (unsigned int)(2 * n_conv)) errorQuda("Incorrect deflation space sized %d passed to computeSVD, expected %d", (int)(evecs.size()), 2 * n_conv); - std::vector sigma_tmp(n_conv); - - for (int i = 0; i < n_conv; i++) { + for (int i = 0; i < n_conv; i += batch_size) { + auto lower = i; + auto upper = i + batch_size < n_conv ? i + batch_size : n_conv; // This function assumes that you have computed the eigenvectors // of MdagM(MMdag), ie, the right(left) SVD of M. The ith eigen vector in the @@ -515,22 +516,27 @@ namespace quda //-------------------------------------------------------------------------- // Lambda already contains the square root of the eigenvalue of the norm op. - Complex lambda = evals[i]; // M*Rev_i = M*Rsv_i = sigma_i Lsv_i - mat.Expose()->M(evecs[n_conv + i], evecs[i]); + mat.Expose()->M({evecs.begin() + n_conv + lower, evecs.begin() + n_conv + upper}, + {evecs.begin() + lower, evecs.begin() + upper}); // sigma_i = sqrt(sigma_i (Lsv_i)^dag * sigma_i * Lsv_i ) - sigma_tmp[i] = sqrt(blas::norm2(evecs[n_conv + i])); + auto sigma = blas::norm2({evecs.begin() + n_conv + lower, evecs.begin() + n_conv + upper}); + decltype(sigma) sigma_inv(sigma.size()); + for (auto j = 0u; j < sigma.size(); j++) { + sigma[j] = sqrt(sigma[j]); + sigma_inv[j] = 1.0 / sigma[j]; + } // Normalise the Lsv: sigma_i Lsv_i -> Lsv_i - blas::ax(1.0 / sigma_tmp[i], evecs[n_conv + i]); + blas::ax(sigma_inv, {evecs.begin() + n_conv + lower, evecs.begin() + n_conv + upper}); - logQuda(QUDA_SUMMARIZE, "Sval[%04d] = %+.16e sigma - sqrt(|lambda|) = %+.16e\n", i, sigma_tmp[i], - sigma_tmp[i] - sqrt(abs(lambda.real()))); - - evals[i] = sigma_tmp[i]; - //-------------------------------------------------------------------------- + for (auto j = 0u; j < sigma.size(); j++) { + logQuda(QUDA_SUMMARIZE, "Sval[%04d] = %+.16e sigma - sqrt(|lambda|) = %+.16e\n", i + j, sigma[j], + sigma[j] - sqrt(abs(evals[i + j].real()))); + evals[i + j] = sigma[j]; + } } } @@ -563,42 +569,56 @@ namespace quda // 2. Perform block caxpy // A_i -> (\sigma_i)^{-1} * A_i // vec_defl = Sum_i (R_i)^{-1} * A_i - if (!accumulate) for (auto &x : sol) blas::zero(x); - for (int i = 0; i < n_defl; i++) s[i] /= evals[i].real(); + for (auto j = 0u; j < src.size(); j++) + for (int i = 0; i < n_defl; i++) { s[i * src.size() + j] /= evals[i].real(); } + // 3. 
Accumulate sum vec_defl = Sum_i V_i * (L_i)^{-1} * A_i + if (!accumulate) blas::zero(sol); blas::block::caxpy(s, {evecs.begin(), evecs.begin() + n_defl}, {sol.begin(), sol.end()}); } void EigenSolver::computeEvals(std::vector &evecs, std::vector &evals, int size) { - if (size > (int)evecs.size()) + auto batch_size = eig_param->compute_evals_batch_size; + + if (size > static_cast(evecs.size())) errorQuda("Requesting %d eigenvectors with only storage allocated for %lu", size, evecs.size()); + + // allocate space if needed for computing the evals + if (size + batch_size > static_cast(evecs.size())) resize(evecs, size + batch_size, QUDA_NULL_FIELD_CREATE); + // we make sure that we have enough space for eigenvalues // required for coarse-grid deflated solver used from within tmLQCD or PLEGMA with // `preserve_deflation` enabled if (size > (int)evals.size()) evals.resize(size); - ColorSpinorParam csParamClone(evecs[0]); - csParamClone.create = QUDA_NULL_FIELD_CREATE; - ColorSpinorField temp(csParamClone); + for (int i = 0; i < size; i += batch_size) { + auto lower = i; + auto upper = i + batch_size < size ? i + batch_size : size; + + auto temp = {evecs.begin() + size, evecs.begin() + size + upper - lower}; - for (int i = 0; i < size; i++) { // r = A * v_i - mat(temp, evecs[i]); + mat(temp, {evecs.begin() + lower, evecs.begin() + upper}); // lambda_i = v_i^dag A v_i / (v_i^dag * v_i) - evals[i] = blas::cDotProduct(evecs[i], temp) / sqrt(blas::norm2(evecs[i])); + auto vtAv = blas::cDotProduct({evecs.begin() + lower, evecs.begin() + upper}, temp); + auto v2 = blas::norm2({evecs.begin() + lower, evecs.begin() + upper}); + for (auto j = 0u; j < v2.size(); j++) evals[i + j] = vtAv[j] / sqrt(v2[j]); // Measure ||lambda_i*v_i - A*v_i|| Complex n_unit(-1.0, 0.0); - blas::caxpby(evals[i], evecs[i], n_unit, temp); - residua[i] = sqrt(blas::norm2(temp)); - // eig_param->invert_param->true_res_offset[i] = residua[i]; + auto res = blas::caxpbyNorm({evals.begin() + lower, evals.begin() + upper}, + {evecs.begin() + lower, evecs.begin() + upper}, n_unit, temp); + for (auto j = 0u; j < v2.size(); j++) residua[i + j] = sqrt(res[j]); // If size = n_conv, this routine is called post sort - if (size == n_conv) - logQuda(QUDA_SUMMARIZE, "Eval[%04d] = (%+.16e,%+.16e) ||%+.16e|| Residual = %+.16e\n", i, evals[i].real(), - evals[i].imag(), abs(evals[i]), residua[i]); + if (size == n_conv) { + for (int j = lower; j < upper; j++) { + logQuda(QUDA_SUMMARIZE, "Eval[%04d] = (%+.16e,%+.16e) ||%+.16e|| Residual = %+.16e\n", j, evals[j].real(), + evals[j].imag(), abs(evals[j]), residua[j]); + } + } } } @@ -624,11 +644,11 @@ namespace quda blas::block::cDotProduct(s, {evecs.begin(), evecs.begin() + n_defl}, {src.begin(), src.end()}); // 2. Perform block caxpy: V_i * (L_i)^{-1} * A_i - for (int i = 0; i < n_defl; i++) { s[i] /= evals[i].real(); } + for (auto j = 0u; j < src.size(); j++) + for (int i = 0; i < n_defl; i++) { s[i * src.size() + j] /= evals[i].real(); } // 3. 
Accumulate sum vec_defl = Sum_i V_i * (L_i)^{-1} * A_i - if (!accumulate) for (auto &x : sol) blas::zero(x); - + if (!accumulate) blas::zero(sol); blas::block::caxpy(s, {evecs.begin(), evecs.begin() + n_defl}, {sol.begin(), sol.end()}); } diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index c012affdcd..5575d36a54 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -108,12 +108,12 @@ CloverField *cloverEigensolver = nullptr; GaugeField momResident; GaugeField *extendedGaugeResident = nullptr; -std::vector solutionResident; +namespace quda +{ - -// vector of spinors used for forecasting solutions in HMC -#define QUDA_MAX_CHRONO 12 -// each entry is one p -std::vector> chronoResident(QUDA_MAX_CHRONO); + std::vector solutionResident; + +} // Mapped memory buffer used to hold unitarization failures static int *num_failures_h = nullptr; @@ -297,6 +297,15 @@ static void profilerStop(const char *f) { namespace quda { void printLaunchTimer(); + + void flushChrono(int i = -1); + + void massRescale(cvector_ref &b, QudaInvertParam &param, bool for_multishift); + + void distanceReweight(cvector_ref &b, QudaInvertParam &param, bool inverse); + + void solve(const std::vector &hp_x, const std::vector &hp_b, QudaInvertParam &param, + const GaugeField &u); } void setVerbosityQuda(QudaVerbosity verbosity, const char prefix[], FILE *outfile) @@ -1322,13 +1331,7 @@ void freeCloverQuda(void) cloverPrecise = nullptr; } -void flushChronoQuda(int i) -{ - if (i >= QUDA_MAX_CHRONO) - errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); - - chronoResident[i].clear(); -} +void flushChronoQuda(int i) { flushChrono(i); } void endQuda(void) { @@ -1340,7 +1343,7 @@ void endQuda(void) freeGaugeQuda(); freeCloverQuda(); - for (int i = 0; i < QUDA_MAX_CHRONO; i++) flushChronoQuda(i); + flushChrono(); solutionResident.clear(); momResident = GaugeField(); @@ -1372,11 +1375,11 @@ void endQuda(void) assertAllMemFree(); device::destroy(); - - comm_finalize(); - comms_initialized = false; } + comm_finalize(); + comms_initialized = false; + profileInit2End.TPSTOP(QUDA_PROFILE_TOTAL); // print out the profile information of the lifetime of the library @@ -1697,115 +1700,11 @@ namespace quda { dEig = Dirac::create(diracEigParam); } - void massRescale(ColorSpinorField &b, QudaInvertParam &param, bool for_multishift) - { - double kappa5 = (0.5/(5.0 + param.m5)); - double kappa = (param.dslash_type == QUDA_DOMAIN_WALL_DSLASH || param.dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH - || param.dslash_type == QUDA_MOBIUS_DWF_DSLASH || param.dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) ? - kappa5 : - param.kappa; - - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: Kappa is: %g\n", kappa); - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: mass normalization: %d\n", param.mass_normalization); - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: norm of source in = %g\n", blas::norm2(b)); - - // staggered dslash uses mass normalization internally - if (param.dslash_type == QUDA_ASQTAD_DSLASH || param.dslash_type == QUDA_STAGGERED_DSLASH) { - switch (param.solution_type) { - case QUDA_MAT_SOLUTION: - case QUDA_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(2.0*param.mass, b); - break; - case QUDA_MATDAG_MAT_SOLUTION: - case QUDA_MATPCDAG_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(4.0*param.mass*param.mass, b); - break; - default: - errorQuda("Not implemented"); - } - return; - } - - // multiply the source to compensate for normalization of the Dirac operator, if necessary - // you are responsible for restoring what's in param.offset - switch (param.solution_type) { - case QUDA_MAT_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION || - param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(2.0*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; - } - break; - case QUDA_MATDAG_MAT_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION || - param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } - break; - case QUDA_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(2.0*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; - } - break; - case QUDA_MATPCDAG_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { - blas::ax(16.0*std::pow(kappa,4), b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 16.0 * std::pow(kappa, 4); - } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } - break; - default: - errorQuda("Solution type %d not supported", param.solution_type); - } - - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: norm of source out = %g\n", blas::norm2(b)); - } -} - -void distanceReweight(ColorSpinorField &b, QudaInvertParam &param, bool inverse) -{ - // Force the alpha0 to be positive. - // A negative alpha0 matches something like Eq.(12) in arXiv:1006.4028. - // Disable the negative situation as QUDA already has multigrid for light quarks.
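For orientation, the helper being moved behind the interface here (see the distanceReweight forward declaration above) implements distance preconditioning as a similarity transform M -> W M W^{-1} with W diagonal in time: the inverse weight is applied to the source before the solve and the forward weight to the result, matching the inverse = true/false call sites. A toy sketch of that pairing; the actual weight profile computed by spinorDistanceReweight is not visible in this hunk, so the exponential below is purely illustrative:

#include <cmath>
#include <cstdlib>
#include <vector>

// Illustrative stand-in for spinorDistanceReweight: scale timeslice t by w(t)
// or by 1/w(t). Only the pairing of the two calls matters here.
void distance_reweight(std::vector<double> &b, double alpha0, int t0, bool inverse)
{
  for (int t = 0; t < (int)b.size(); t++) {
    double w = std::exp(std::abs(alpha0) * std::abs(t - t0)); // hypothetical profile
    b[t] *= inverse ? 1.0 / w : w;
  }
}

Usage mirrors the interface: reweight the source with inverse = true, solve, then reweight the solution with inverse = false, so the two weights cancel apart from their conditioning effect on the solve.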
- const double alpha0 = abs(param.distance_pc_alpha0); - const int t0 = param.distance_pc_t0; - if (alpha0 != 0.0 && t0 >= 0) { - if (param.dslash_type != QUDA_WILSON_DSLASH && param.dslash_type != QUDA_CLOVER_WILSON_DSLASH) { - errorQuda("Only Wilson and Wilson-clover dslash support distance preconditioning, but get dslash_type %d\n", - param.dslash_type); - } - if (param.inv_type == QUDA_MG_INVERTER) { - errorQuda("Multigrid solver doesn't support distance preconditioning\n"); - } - if (param.cuda_prec != QUDA_DOUBLE_PRECISION || param.cuda_prec_sloppy != QUDA_DOUBLE_PRECISION) { - warningQuda( - "Using single or half (sloppy) precision in distance preconditioning sometimes makes the solver diverge"); - } - - if (inverse) - spinorDistanceReweight(b, -alpha0, t0); - else - spinorDistanceReweight(b, alpha0, t0); - } } void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity parity) { - auto profile = pushProfile(profileDslash, inv_param->secs, inv_param->gflops); - + auto profile = pushProfile(profileDslash, inv_param); const auto &gauge = (inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; if ((!gaugePrecise && inv_param->dslash_type != QUDA_ASQTAD_DSLASH) @@ -1834,7 +1733,7 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity in = in_h; - profileDslash.TPSTART(QUDA_PROFILE_COMPUTE); + getProfile().TPSTART(QUDA_PROFILE_COMPUTE); logQuda(QUDA_DEBUG_VERBOSE, "In CPU %e CUDA %e\n", blas::norm2(in_h), blas::norm2(in)); @@ -1866,7 +1765,7 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity } else { dirac->Dslash(out, in, parity); // apply the operator } - profileDslash.TPSTOP(QUDA_PROFILE_COMPUTE); + getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); distanceReweight(out, *inv_param, false); @@ -1881,8 +1780,7 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity void shiftQuda(void *h_out, void *h_in, int dir, int sym, QudaInvertParam *param) { - auto profile = pushProfile(profileCovDev, param->secs, param->gflops); - + auto profile = pushProfile(profileCovDev, param); const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; QudaInvertParam &inv_param = *param; @@ -1951,8 +1849,7 @@ void shiftQuda(void *h_out, void *h_in, int dir, int sym, QudaInvertParam *param void spinTasteQuda(void *h_out, void *h_in, int spin_, int taste, QudaInvertParam *param) { - auto profile = pushProfile(profileCovDev, param->secs, param->gflops); - + auto profile = pushProfile(profileCovDev, param); const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; QudaInvertParam &inv_param = *param; @@ -2262,10 +2159,9 @@ void spinTasteQuda(void *h_out, void *h_in, int spin_, int taste, QudaInvertPara void covDevQuda(void *h_out, void *h_in, int dir, QudaInvertParam *param) { - auto profile = pushProfile(profileCovDev, param->secs, param->gflops); + auto profile = pushProfile(profileCovDev, param); QudaInvertParam &inv_param = *param; - const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; inv_param.solution_type = QUDA_MAT_SOLUTION; @@ -2629,7 +2525,7 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam // This will define the operator to be eigensolved. 
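For orientation, the guard deleted above only validates parameters and then dispatches to spinorDistanceReweight(), which multiplies each timeslice by a weight growing with temporal distance from t0 (and, for the inverse, by its reciprocal). A self-contained sketch of those weights, assuming the cosh profile of arXiv:1006.4028 and ignoring the periodic wrap-around the real kernel would handle:

    #include <cmath>
    #include <vector>

    // Hypothetical per-timeslice weights for distance preconditioning; the actual
    // implementation lives in spinorDistanceReweight(), which this diff does not show.
    std::vector<double> distance_weights(int Lt, double alpha0, int t0, bool inverse)
    {
      std::vector<double> w(Lt);
      for (int t = 0; t < Lt; t++) {
        const double wt = std::cosh(alpha0 * (t - t0)); // grows away from the source timeslice
        w[t] = inverse ? 1.0 / wt : wt;                 // inverse == true undoes the weighting
      }
      return w;
    }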
QudaInvertParam *inv_param = eig_param->invert_param; - auto profile = pushProfile(profileEigensolve, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileEigensolve, inv_param); // QUDA can employ even-odd preconditioning to an operator. // For the eigensolver the solution type must match @@ -2790,9 +2686,6 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam delete dEig; popVerbosity(); - - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); } multigrid_solver::multigrid_solver(QudaMultigridParam &mg_param) @@ -2870,21 +2763,16 @@ multigrid_solver::multigrid_solver(QudaMultigridParam &mg_param) mg = new MG(*mgParam); mgParam->updateInvertParam(*param); - - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); } void *newMultigridQuda(QudaMultigridParam *mg_param) { profilerStart(__func__); - auto profile = pushProfile(profileInvert, mg_param->secs, mg_param->gflops); + auto profile = pushProfile(profileInvert, mg_param->invert_param); pushVerbosity(mg_param->invert_param->verbosity); auto *mg = new multigrid_solver(*mg_param); - saveTuneCache(); - popVerbosity(); profilerStop(__func__); return static_cast(mg); @@ -2897,7 +2785,7 @@ void destroyMultigridQuda(void *mg) { void updateMultigridQuda(void *mg_, QudaMultigridParam *mg_param) { profilerStart(__func__); - auto profile = pushProfile(profileInvert, mg_param->secs, mg_param->gflops); + auto profile = pushProfile(profileInvert, mg_param->invert_param); pushVerbosity(mg_param->invert_param->verbosity); profileInvert.TPSTART(QUDA_PROFILE_PREAMBLE); @@ -2997,9 +2885,6 @@ void updateMultigridQuda(void *mg_, QudaMultigridParam *mg_param) setOutputPrefix(""); - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); - profileInvert.TPSTOP(QUDA_PROFILE_PREAMBLE); popVerbosity(); @@ -3009,7 +2894,7 @@ void updateMultigridQuda(void *mg_, QudaMultigridParam *mg_param) void dumpMultigridQuda(void *mg_, QudaMultigridParam *mg_param) { profilerStart(__func__); - auto profile = pushProfile(profileInvert, mg_param->secs, mg_param->gflops); + auto profile = pushProfile(profileInvert, mg_param->invert_param); pushVerbosity(mg_param->invert_param->verbosity); auto *mg = static_cast(mg_); @@ -3084,7 +2969,7 @@ deflated_solver::deflated_solver(QudaEigParam &eig_param, TimeProfile &profile) } void* newDeflationQuda(QudaEigParam *eig_param) { - auto profile = pushProfile(profileInvert, eig_param->secs, eig_param->gflops); + auto profile = pushProfile(profileInvert, eig_param->invert_param); auto *defl = new deflated_solver(*eig_param, profileInvert); saveProfile(__func__); flushProfile(); @@ -3097,7 +2982,7 @@ void destroyDeflationQuda(void *df) { void invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param) { - auto profile = pushProfile(profileInvert, param->secs, param->gflops); + auto profile = pushProfile(profileInvert, param); profilerStart(__func__); if (!initialized) errorQuda("QUDA not initialized"); @@ -3110,259 +2995,10 @@ void invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param) // check the gauge fields have been created GaugeField *cudaGauge = checkGauge(param); - // It was probably a bad design decision to encode whether the system is even/odd preconditioned (PC) in - // solve_type and solution_type, rather than in separate members of QudaInvertParam. We're stuck with it - // for now, though, so here we factorize everything for convenience. 
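A pattern worth noting across this file: every pushProfile(profileX, param->secs, param->gflops) becomes pushProfile(profileX, param), and the trailing "cache is written out even if a long benchmarking job gets interrupted" saveTuneCache() calls disappear. That is consistent with an RAII guard that harvests timing and flops into the param, and persists the tune cache, when it leaves scope. A hypothetical sketch of that shape (QUDA's actual helper may differ):

    #include <chrono>

    struct ParamLike { double secs = 0.0; double gflops = 0.0; };

    // Hypothetical RAII guard: accounting that was previously written by hand at
    // every call site now happens once, in the destructor, even on early returns.
    class ProfileGuard {
      ParamLike *param;
      std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
    public:
      explicit ProfileGuard(ParamLike *p) : param(p) {}
      ~ProfileGuard()
      {
        using sec = std::chrono::duration<double>;
        if (param) param->secs += sec(std::chrono::steady_clock::now() - start).count();
        // a natural place to also flush the autotuning cache, which would explain
        // why the explicit saveTuneCache() calls are deleted throughout this diff
      }
    };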
- - bool pc_solution = (param->solution_type == QUDA_MATPC_SOLUTION) || - (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); - bool pc_solve = (param->solve_type == QUDA_DIRECT_PC_SOLVE) || - (param->solve_type == QUDA_NORMOP_PC_SOLVE) || (param->solve_type == QUDA_NORMERR_PC_SOLVE); - bool mat_solution = (param->solution_type == QUDA_MAT_SOLUTION) || - (param->solution_type == QUDA_MATPC_SOLUTION); - bool direct_solve = (param->solve_type == QUDA_DIRECT_SOLVE) || - (param->solve_type == QUDA_DIRECT_PC_SOLVE); - bool norm_error_solve = (param->solve_type == QUDA_NORMERR_SOLVE) || - (param->solve_type == QUDA_NORMERR_PC_SOLVE); - - param->iter = 0; - - Dirac *d = nullptr; - Dirac *dSloppy = nullptr; - Dirac *dPre = nullptr; - Dirac *dEig = nullptr; - - // Create the dirac operator and operators for sloppy, precondition, - // and an eigensolver - createDiracWithEig(d, dSloppy, dPre, dEig, *param, pc_solve); - - Dirac &dirac = *d; - Dirac &diracSloppy = *dSloppy; - Dirac &diracPre = *dPre; - Dirac &diracEig = *dEig; - - // wrap CPU host side pointers - ColorSpinorParam cpuParam(hp_b, *param, cudaGauge->X(), pc_solution, param->input_location); - ColorSpinorField h_b(cpuParam); - - cpuParam.v = hp_x; - cpuParam.location = param->output_location; - ColorSpinorField h_x(cpuParam); - - // download source - ColorSpinorParam cudaParam(cpuParam, *param, QUDA_CUDA_FIELD_LOCATION); - cudaParam.create = QUDA_COPY_FIELD_CREATE; - cudaParam.field = &h_b; - ColorSpinorField b(cudaParam); - - // now check if we need to invalidate the solutionResident vectors - ColorSpinorField x; - if (param->use_resident_solution == 1) { - for (auto &v : solutionResident) { - if (b.Precision() != v.Precision() || b.SiteSubset() != v.SiteSubset()) { - solutionResident.clear(); - break; - } - } - - if (!solutionResident.size()) { - cudaParam.create = QUDA_NULL_FIELD_CREATE; - solutionResident = std::vector(1, cudaParam); - } - x = solutionResident[0].create_alias(cudaParam); - } else { - cudaParam.create = QUDA_NULL_FIELD_CREATE; - x = ColorSpinorField(cudaParam); - } - - if (param->use_init_guess == QUDA_USE_INIT_GUESS_YES && !param->chrono_use_resident) { // download initial guess - // initial guess only supported for single-pass solvers - if ((param->solution_type == QUDA_MATDAG_MAT_SOLUTION || param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION) && - (param->solve_type == QUDA_DIRECT_SOLVE || param->solve_type == QUDA_DIRECT_PC_SOLVE)) { - errorQuda("Initial guess not supported for two-pass solver"); - } - - x = h_x; // solution - } else { // zero initial guess - blas::zero(x); - } - - // if we're doing a managed memory MG solve and prefetching is - // enabled, prefetch all the Dirac matrices. There's probably - // a better place to put this... 
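The resident-solution block above relies on aliasing: when use_resident_solution is set, x is created as an alias of solutionResident[0], so the solver writes directly into the resident field and no copy-out is needed when make_resident_solution is also set. A stand-in illustration of the alias-or-allocate idiom (simplified types, not QUDA's):

    #include <memory>
    #include <vector>

    struct Field { std::shared_ptr<std::vector<float>> data; };

    // Either share storage with the resident field (alias) or own a fresh buffer.
    Field make_solution(std::vector<Field> &resident, bool use_resident, size_t n)
    {
      if (use_resident) {
        if (resident.empty()) resident.push_back({std::make_shared<std::vector<float>>(n)});
        return {resident[0].data}; // writes through this handle land in the resident copy
      }
      return {std::make_shared<std::vector<float>>(n)}; // independent allocation
    }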
- if (param->inv_type_precondition == QUDA_MG_INVERTER) { - dirac.prefetch(QUDA_CUDA_FIELD_LOCATION); - diracSloppy.prefetch(QUDA_CUDA_FIELD_LOCATION); - diracPre.prefetch(QUDA_CUDA_FIELD_LOCATION); - } - - profileInvert.TPSTART(QUDA_PROFILE_PREAMBLE); - - double nb = blas::norm2(b); - if (nb==0.0) errorQuda("Source has zero norm"); - logQuda(QUDA_VERBOSE, "Source: %g\n", nb); - if (param->use_init_guess == QUDA_USE_INIT_GUESS_YES) logQuda(QUDA_VERBOSE, "Initial guess: %g\n", blas::norm2(x)); - - // rescale the source and solution vectors to help prevent the onset of underflow - if (param->solver_normalization == QUDA_SOURCE_NORMALIZATION) { - blas::ax(1.0 / sqrt(nb), b); - blas::ax(1.0 / sqrt(nb), x); - } - - massRescale(b, *param, false); - distanceReweight(b, *param, true); - - ColorSpinorField in; - ColorSpinorField out; - dirac.prepare(out, in, x, b, param->solution_type); - - logQuda(QUDA_VERBOSE, "Prepared source = %g\n", blas::norm2(in)); - logQuda(QUDA_VERBOSE, "Prepared solution = %g\n", blas::norm2(out)); - - // solution_type specifies *what* system is to be solved. - // solve_type specifies *how* the system is to be solved. - // - // We have the following four cases (plus preconditioned variants): - // - // solution_type solve_type Effect - // ------------- ---------- ------ - // MAT DIRECT Solve Ax=b - // MATDAG_MAT DIRECT Solve A^dag y = b, followed by Ax=y - // MAT NORMOP Solve (A^dag A) x = (A^dag b) - // MATDAG_MAT NORMOP Solve (A^dag A) x = b - // MAT NORMERR Solve (A A^dag) y = b, then x = A^dag y - // - // We generally require that the solution_type and solve_type - // preconditioning match. As an exception, the unpreconditioned MAT - // solution_type may be used with any solve_type, including - // DIRECT_PC and NORMOP_PC. In these cases, preparation of the - // preconditioned source and reconstruction of the full solution are - // taken care of by Dirac::prepare() and Dirac::reconstruct(), - // respectively. - - profileInvert.TPSTOP(QUDA_PROFILE_PREAMBLE); - - if (mat_solution && !direct_solve && !norm_error_solve) { // prepare source: b' = A^dag b - ColorSpinorField tmp(in); - dirac.Mdag(in, tmp); - } else if (!mat_solution && direct_solve) { // perform the first of two solves: A^dag y = b - DiracMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - SolverParam solverParam(*param); - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(out, in); - blas::copy(in, out); - delete solve; - solverParam.updateInvertParam(*param); - } - - if (direct_solve) { - DiracM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - SolverParam solverParam(*param); - - // chronological forecasting - if (param->chrono_use_resident && chronoResident[param->chrono_index].size() > 0) { - bool hermitian = false; - auto &mChrono = param->chrono_precision == param->cuda_prec ? m : mSloppy; - chronoExtrapolate(out, in, chronoResident[param->chrono_index], mChrono, hermitian); - } - - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(out, in); - delete solve; - solverParam.updateInvertParam(*param); - } else if (!norm_error_solve) { - DiracMdagM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - SolverParam solverParam(*param); - - // chronological forecasting - if (param->chrono_use_resident && chronoResident[param->chrono_index].size() > 0) { - bool hermitian = true; - auto &mChrono = param->chrono_precision == param->cuda_prec ? 
m : mSloppy; - chronoExtrapolate(out, in, chronoResident[param->chrono_index], mChrono, hermitian); - } - - // if using a Schwarz preconditioner with a normal operator then we must use the DiracMdagMLocal operator - if (param->inv_type_precondition != QUDA_INVALID_INVERTER && param->schwarz_type != QUDA_INVALID_SCHWARZ) { - DiracMdagMLocal mPreLocal(diracPre); - Solver *solve = Solver::create(solverParam, m, mSloppy, mPreLocal, mEig); - (*solve)(out, in); - delete solve; - solverParam.updateInvertParam(*param); - } else { - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(out, in); - delete solve; - solverParam.updateInvertParam(*param); - } - } else { // norm_error_solve - DiracMMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - ColorSpinorField tmp(out); - SolverParam solverParam(*param); - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(tmp, in); // y = (M M^\dag) b - dirac.Mdag(out, tmp); // x = M^dag y - delete solve; - solverParam.updateInvertParam(*param); - } - - logQuda(QUDA_VERBOSE, "Solution = %g\n", blas::norm2(x)); - - profileInvert.TPSTART(QUDA_PROFILE_EPILOGUE); - if (param->chrono_make_resident) { - const int i = param->chrono_index; - if (i >= QUDA_MAX_CHRONO) - errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); - - auto &basis = chronoResident[i]; - - if (param->chrono_max_dim < (int)basis.size()) { - errorQuda("Requested chrono_max_dim %i is smaller than already existing chronology %lu", param->chrono_max_dim, basis.size()); - } - - if(not param->chrono_replace_last){ - // if we have not filled the space yet just augment - if ((int)basis.size() < param->chrono_max_dim) { - ColorSpinorParam cs_param(out); - cs_param.setPrecision(param->chrono_precision); - basis.emplace_back(cs_param); - } - - // shuffle every entry down one and bring the last to the front - std::rotate(basis.begin(), basis.end() - 1, basis.end()); - } - basis[0] = out; // set first entry to new solution - } - dirac.reconstruct(x, b, param->solution_type); - - distanceReweight(x, *param, false); - - if (param->solver_normalization == QUDA_SOURCE_NORMALIZATION) { - // rescale the solution - blas::ax(sqrt(nb), x); - } - - if (param->compute_action) { - Complex action = blas::cDotProduct(b, x); - param->action[0] = action.real(); - param->action[1] = action.imag(); - } - - profileInvert.TPSTOP(QUDA_PROFILE_EPILOGUE); - - if (!param->make_resident_solution) h_x = x; - - logQuda(QUDA_VERBOSE, "Reconstructed solution: %g\n", blas::norm2(x)); + solve({hp_x}, {hp_b}, *param, *cudaGauge); if (param->use_resident_solution && !param->make_resident_solution) solutionResident.clear(); - delete d; - delete dSloppy; - delete dPre; - delete dEig; - - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); - profilerStop(__func__); popVerbosity(); } @@ -3431,7 +3067,7 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col */ profilerStart(__func__); - auto profile = pushProfile(profileInvertMultiSrc, param->secs, param->gflops); + auto profile = pushProfile(profileInvertMultiSrc, param); CommKey split_key = {param->split_grid[0], param->split_grid[1], param->split_grid[2], param->split_grid[3]}; int num_sub_partition = quda::product(split_key); @@ -3440,9 +3076,14 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col errorQuda("split_key = [%d,%d,%d,%d] is not valid", split_key[0], split_key[1], split_key[2], 
split_key[3]);
   }

+  checkInvertParam(param, _hp_x[0], _hp_b[0]);
+
   if (num_sub_partition == 1) { // In this case we don't split the grid.
-    for (int n = 0; n < param->num_src; n++) { op(_hp_x[n], _hp_b[n], param, args...); }
+    std::vector<void *> x(param->num_src), b(param->num_src);
+    for (auto i = 0u; i < x.size(); i++) x[i] = _hp_x[i];
+    for (auto i = 0u; i < b.size(); i++) b[i] = _hp_b[i];
+    op(x, b, *param, args...);
   } else {

@@ -3648,23 +3289,30 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
     // Since input fields are in Native order now
     param_copy.dirac_order = QUDA_INTERNAL_DIRAC_ORDER;

-      // We need to set the cpu_prec in the param_copy, because the op() passed in
-      // to us will try to create wrappers to the pointers we pass in. They expect
-      // the input spinors to be on the host, and will use param_copy.cpu_prec to set
-      // the precision. We want to avoid the situation, where the internal prec and the
-      // cpu_prec are somehow different.
-      param_copy.cpu_prec = _collect_b[0].Precision();
+    // We need to set the cpu_prec in the param_copy, because the op() passed in
+    // to us will try to create wrappers to the pointers we pass in. They expect
+    // the input spinors to be on the host, and will use param_copy.cpu_prec to set
+    // the precision. We want to avoid the situation, where the internal prec and the
+    // cpu_prec are somehow different.
+    param_copy.cpu_prec = _collect_b[0].Precision();

     // Do the solves
-    for (int n = 0; n < param->num_src_per_sub_partition; n++) {
-      op(_collect_x[n].data(), _collect_b[n].data(), &param_copy, args...);
-    }
+    std::vector<void *> x_raw(param->num_src_per_sub_partition);
+    std::vector<void *> b_raw(param->num_src_per_sub_partition);
+    for (auto i = 0u; i < x_raw.size(); i++) x_raw[i] = _collect_x[i].data();
+    for (auto i = 0u; i < b_raw.size(); i++) b_raw[i] = _collect_b[i].data();
+    op(x_raw, b_raw, param_copy, args...);
+
+    auto split_rank = comm_rank();

     profileInvertMultiSrc.TPSTART(QUDA_PROFILE_EPILOGUE);
     push_communicator(default_comm_key);
     updateR();
     comm_barrier();

+    // back to the default communicator, now join the param entries
+    joinInvertParam(*param, param_copy, split_key, split_rank);
+
     // Join spinors: _h_x are aliases to host pointers in 'external order: QDP++, QDP-JIT, etc'
     for (int n = 0; n < param->num_src_per_sub_partition; n++) {
       // join fields
@@ -3722,14 +3370,19 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col

 void invertMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param)
 {
-  auto op = [](void *_x, void *_b, QudaInvertParam *param) { invertQuda(_x, _b, param); };
+  auto op = [](const std::vector<void *> &_x, const std::vector<void *> &_b, QudaInvertParam &param) {
+    // check the gauge fields have been created
+    GaugeField *gauge = checkGauge(&param);
+    solve(_x, _b, param, *gauge);
+  };
   callMultiSrcQuda(_hp_x, _hp_b, param, op);
 }
-
 void dslashMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, QudaParity parity)
 {
-  auto op = [](void *_x, void *_b, QudaInvertParam *param, QudaParity parity) { dslashQuda(_x, _b, param, parity); };
+  auto op = [](const std::vector<void *> &_x, const std::vector<void *> &_b, QudaInvertParam &param, QudaParity parity) {
+    for (auto i = 0u; i < _b.size(); i++) dslashQuda(_x[i], _b[i], &param, parity);
+  };
   callMultiSrcQuda(_hp_x, _hp_b, param, op, parity);
 }

@@ -3747,7 +3400,7 @@ void dslashMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, Quda
 */
 void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param)
 {
-  auto profile =
pushProfile(profileMulti, param->secs, param->gflops); + auto profile = pushProfile(profileMulti, param); profilerStart(__func__); if (!initialized) errorQuda("QUDA not initialized"); @@ -4014,14 +3667,14 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) { CG cg(*m, *mSloppy, *mSloppy, *mSloppy, solverParam); - if (i==0) - cg(x[i], b, &p[i], r2_old[i]); + if (i == 0) + cg(x[i], b, p[i], r2_old[i]); else cg(x[i], b); } - solverParam.true_res_offset[i] = solverParam.true_res; - solverParam.true_res_hq_offset[i] = solverParam.true_res_hq; + solverParam.true_res_offset[i] = static_cast(solverParam.true_res); + solverParam.true_res_hq_offset[i] = static_cast(solverParam.true_res_hq); solverParam.updateInvertParam(*param,i); if (param->dslash_type == QUDA_ASQTAD_DSLASH || @@ -4066,9 +3719,6 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) delete dPre; delete dRefine; - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); - profilerStop(__func__); popVerbosity(); } @@ -4887,7 +4537,7 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double QudaInvertParam *inv_param) { using namespace quda; - auto profile = pushProfile(profileCloverForce, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileCloverForce, inv_param); checkGaugeParam(gauge_param); if (!gaugePrecise) errorQuda("No resident gauge field"); @@ -4941,6 +4591,8 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double // Make sure extendedGaugeResident has the correct R if (extendedGaugeResident) delete extendedGaugeResident; + lat_dim_t R; + for (int d = 0; d < 4; d++) R[d] = (d == 0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d)); extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, getProfile()); GaugeField &gaugeEx = *extendedGaugeResident; @@ -4959,7 +4611,7 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef QudaGaugeParam *gauge_param, QudaInvertParam *inv_param, int detratio) { using namespace quda; - auto profile = pushProfile(profileTMCloverForce, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileTMCloverForce, inv_param); checkGaugeParam(gauge_param); if (!gaugePrecise) errorQuda("No resident gauge field"); @@ -5014,6 +4666,8 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef // Make sure extendedGaugeResident has the correct R if (extendedGaugeResident) delete extendedGaugeResident; + lat_dim_t R; + for (int d = 0; d < 4; d++) R[d] = (d == 0 ? 
2 : 1) * (redundant_comms || commDimPartitioned(d)); extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profileTMCloverForce); GaugeField &gaugeEx = *extendedGaugeResident; @@ -5341,7 +4995,7 @@ void performWuppertalnStep(void *h_out, void *h_in, QudaInvertParam *inv_param, void performTwoLinkGaussianSmearNStep(void *h_in, QudaQuarkSmearParam *smear_param) { if (smear_param->n_steps == 0) return; - auto profile = pushProfile(profileGaussianSmear, smear_param->secs, smear_param->gflops); + auto profile = pushProfile(profileGaussianSmear, smear_param); QudaInvertParam *inv_param = smear_param->inv_param; @@ -5450,8 +5104,6 @@ void performTwoLinkGaussianSmearNStep(void *h_in, QudaQuarkSmearParam *smear_par delete d; if (smear_param->delete_2link != 0) { freeUniqueGaugeQuda(QUDA_SMEARED_LINKS); } - - saveTuneCache(); } void performGaugeSmearQuda(QudaGaugeSmearParam *smear_param, QudaGaugeObservableParam *obs_param) @@ -5725,8 +5377,6 @@ void contractFTQuda(void **prop_array_flavor_1, void **prop_array_flavor_2, void } } profileContractFT.TPSTOP(QUDA_PROFILE_COMPUTE); - - saveTuneCache(); } void contractQuda(const void *hp_x, const void *hp_y, void *h_result, const QudaContractType cType, @@ -5817,7 +5467,7 @@ static void check_param(double _Complex *host_sinks, void **host_quark, int n_qu void laphSinkProject(double _Complex *host_sinks, void **host_quark, int n_quark, int tile_quark, void **host_evec, int n_evec, int tile_evec, QudaInvertParam *inv_param, const int X[4]) { - auto profile = pushProfile(profileSinkProject, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileSinkProject, inv_param); // check parameters are valid check_param(host_sinks, host_quark, n_quark, tile_quark, host_evec, n_evec, tile_evec, inv_param, X); diff --git a/lib/inv_bicgstab_quda.cpp b/lib/inv_bicgstab_quda.cpp index 7762ad1c10..2cffc2aeed 100644 --- a/lib/inv_bicgstab_quda.cpp +++ b/lib/inv_bicgstab_quda.cpp @@ -19,28 +19,49 @@ namespace quda { BiCGstab::~BiCGstab() { destroyDeflationSpace(); } - void BiCGstab::create(ColorSpinorField &x, const ColorSpinorField &b) + void BiCGstab::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { getProfile().TPSTART(QUDA_PROFILE_INIT); - ColorSpinorParam csParam(x); + ColorSpinorParam csParam(x[0]); csParam.create = QUDA_ZERO_FIELD_CREATE; - y = ColorSpinorField(csParam); - r = ColorSpinorField(csParam); + resize(y, b.size(), csParam); + resize(r, b.size(), csParam); csParam.setPrecision(param.precision_sloppy); - p = ColorSpinorField(csParam); - v = ColorSpinorField(csParam); - t = ColorSpinorField(csParam); + resize(p, b.size(), csParam); + resize(v, b.size(), csParam); + resize(t, b.size(), csParam); + + if (param.precision_sloppy == x.Precision()) { + create_alias(r_sloppy, r); + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + create_alias(r0, b); + } else { + csParam.create = QUDA_NULL_FIELD_CREATE; + resize(r0, b.size(), csParam); + blas::copy(r0, r); + } + } else { + csParam.create = QUDA_NULL_FIELD_CREATE; + resize(r_sloppy, b.size(), csParam); + resize(r0, b.size(), csParam); + } + + if (param.precision_sloppy == x.Precision() || !param.use_sloppy_partial_accumulator) { + create_alias(x_sloppy, x); + } else { + resize(x_sloppy, b.size(), csParam); + } - getProfile().TPSTOP(QUDA_PROFILE_INIT); init = true; + getProfile().TPSTOP(QUDA_PROFILE_INIT); } // init } - ColorSpinorField &BiCGstab::get_residual() + cvector_ref BiCGstab::get_residual() { if 
(!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -56,27 +77,31 @@ namespace quda { //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0 int updateR = (rNorm < delta*maxrr) ? 1 : 0; - //printf("reliable %d %e %e %e %e\n", updateR, rNorm, maxrx, maxrr, r2); - return updateR; } - void BiCGstab::operator()(ColorSpinorField &x, ColorSpinorField &b) + void BiCGstab::operator()(cvector_ref &x, cvector_ref &b) { create(x, b); getProfile().TPSTART(QUDA_PROFILE_INIT); - double b2 = blas::norm2(b); // norm sq of source - double r2 = 0.0; // norm sq of residual + auto b2 = blas::norm2(b); // norm sq of source + vector r2(b.size(), 0.0); // norm sq of residual + + // Check to see that we're not trying to invert on a zero-field source + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); + return; + } if (param.deflate) { // Construct the eigensolver and deflation space if requested. if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { @@ -101,7 +126,7 @@ namespace quda { if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); r2 = blas::xmyNorm(b, r); - blas::copy(y, x); + for (auto i = 0u; i < x.size(); i++) std::swap(y[i], x[i]); } else { blas::copy(r, b); r2 = b2; @@ -117,62 +142,24 @@ namespace quda { r2 = blas::xmyNorm(b, r); } - // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - getProfile().TPSTOP(QUDA_PROFILE_INIT); - return; - } else if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - b2 = r2; - } else { - errorQuda("Null vector computing requires non-zero guess!"); - } - } - - // set field aliasing according to whether we are doing mixed precision or not - if (param.precision_sloppy == x.Precision()) { - r_sloppy = r.create_alias(); - - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - r0 = b.create_alias(); - } else { - ColorSpinorParam csParam(r); - csParam.create = QUDA_NULL_FIELD_CREATE; - r0 = ColorSpinorField(csParam); - blas::copy(r0, r); - } - } else { - ColorSpinorParam csParam(x); - csParam.setPrecision(param.precision_sloppy); - csParam.create = QUDA_NULL_FIELD_CREATE; - r_sloppy = ColorSpinorField(csParam); + if (param.precision != param.precision_sloppy) { blas::copy(r_sloppy, r); - r0 = ColorSpinorField(csParam); - blas::copy(r0, r); - } - - if (param.precision_sloppy == x.Precision() || !param.use_sloppy_partial_accumulator) { - x_sloppy = x.create_alias(); - blas::zero(x_sloppy); - } else { - ColorSpinorParam csParam(x); - csParam.create = QUDA_ZERO_FIELD_CREATE; - csParam.setPrecision(param.precision_sloppy); - x_sloppy = ColorSpinorField(csParam); + blas::copy(r0, param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO ? 
b : r); } getProfile().TPSTOP(QUDA_PROFILE_INIT); getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = std::vector(b.size(), param.tol_hq); const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; - double heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0; + vector heavy_quark_res(b.size(), 0.0); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual double delta = param.delta; @@ -180,16 +167,15 @@ namespace quda { int k = 0; int rUpdate = 0; - Complex rho(1.0, 0.0); - Complex rho0 = rho; - Complex alpha(1.0, 0.0); - Complex omega(1.0, 0.0); - Complex beta; + vector rho(b.size(), {1.0, 0.0}); + vector rho0 = rho; + vector alpha(b.size(), {1.0, 0.0}); + vector omega(b.size(), {1.0, 0.0}); + vector beta(b.size()); - double3 rho_r2; - double3 omega_t2; + vector rho_r2(b.size()); - double rNorm = sqrt(r2); + double rNorm = sqrt(r2[0]); //double r0Norm = rNorm; double maxrr = rNorm; double maxrx = rNorm; @@ -202,10 +188,7 @@ namespace quda { rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab blas::copy(p, r_sloppy); - bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); - - logQuda(QUDA_DEBUG_VERBOSE, "BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, r0=%e, t2=%e\n", blas::norm2(x), - blas::norm2(r_sloppy), blas::norm2(v), blas::norm2(p), blas::norm2(r0), blas::norm2(t)); + bool converged = convergence(r2, heavy_quark_res, stop, stop_hq); // track if we just performed an exact recalculation of y, r, r2 bool just_updated = false; @@ -215,15 +198,19 @@ namespace quda { matSloppy(v, p); - Complex r0v; + vector r0v; if (param.pipeline) { r0v = blas::cDotProduct(r0, v); if (k > 0) rho = blas::cDotProduct(r0, r); } else { r0v = blas::cDotProduct(r0, v); } - if (abs(rho) == 0.0) alpha = 0.0; - else alpha = rho / r0v; + for (auto i = 0u; i < b.size(); i++) { + if (abs(rho[i]) == 0.0) + alpha[i] = 0.0; + else + alpha[i] = rho[i] / r0v[i]; + } // r -= alpha*v blas::caxpy(-alpha, v, r_sloppy); @@ -233,55 +220,63 @@ namespace quda { int updateR = 0; if (param.pipeline) { // omega = (t, r) / (t, t) - omega_t2 = blas::cDotProductNormA(t, r_sloppy); - Complex tr = Complex(omega_t2.x, omega_t2.y); - double t2 = omega_t2.z; - omega = tr / t2; - double s2 = blas::norm2(r_sloppy); - Complex r0t = blas::cDotProduct(r0, t); - beta = -r0t / r0v; - r2 = s2 - real(omega * conj(tr)); + auto omega_t2_s2 = blas::cDotProductNormAB(t, r_sloppy); + auto r0t = blas::cDotProduct(r0, t); + + for (auto i = 0u; i < b.size(); i++) { + omega[i] = Complex {omega_t2_s2[i].x, omega_t2_s2[i].y} / omega_t2_s2[i].z; + beta[i] = -r0t[i] / r0v[i]; + r2[i] = omega_t2_s2[i].w - real(omega[i] * conj(Complex {omega_t2_s2[i].x, omega_t2_s2[i].y})); + } // now we can work out if we need to do a reliable update - updateR = reliable(rNorm, maxrx, maxrr, r2, delta); + updateR = reliable(rNorm, maxrx, maxrr, r2[0], delta); } else { // omega = (t, r) / (t, t) - omega_t2 = blas::cDotProductNormA(t, r_sloppy); - omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z); + auto omega_t2 = blas::cDotProductNormA(t, r_sloppy); + for (auto i = 0u; i < 
b.size(); i++) + omega[i] = Complex(omega_t2[i].x / omega_t2[i].z, omega_t2[i].y / omega_t2[i].z); } if (param.pipeline && !updateR) { // x += alpha*p + omega*r, r -= omega*t, p = r - beta*omega*v + beta*p blas::caxpbypzYmbw(alpha, p, omega, r_sloppy, x_sloppy, t); - blas::cxpaypbz(r_sloppy, -beta * omega, v, beta, p); + vector beta_omega(b.size()); + for (auto i = 0u; i < b.size(); i++) beta_omega[i] = -beta[i] * omega[i]; + blas::cxpaypbz(r_sloppy, beta_omega, v, beta, p); // tripleBiCGstabUpdate(alpha, p, omega, r_sloppy, x_sloppy, t, -beta*omega, v, beta, p } else { // x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r) rho_r2 = blas::caxpbypzYmbwcDotProductUYNormY(alpha, p, omega, r_sloppy, x_sloppy, t, r0); rho0 = rho; - rho = Complex(rho_r2.x, rho_r2.y); - r2 = rho_r2.z; + for (auto i = 0u; i < b.size(); i++) { + rho[i] = Complex(rho_r2[i].x, rho_r2[i].y); + r2[i] = rho_r2[i].z; + } } if (use_heavy_quark_res && k % heavy_quark_check == 0) { - if (&x != &x_sloppy) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x_sloppy, r_sloppy).z); + vector hq; + + if (x.Precision() != x_sloppy[0].Precision()) { + hq = blas::HeavyQuarkResidualNorm(x_sloppy, r_sloppy); } else { blas::copy(r, r_sloppy); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); + hq = blas::xpyHeavyQuarkResidualNorm(x, y, r); } + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } - if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2, delta); + if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2[0], delta); if (updateR) { - if (x.Precision() != x_sloppy.Precision()) blas::copy(x, x_sloppy); + if (x.Precision() != x_sloppy[0].Precision()) blas::copy(x, x_sloppy); blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); @@ -290,10 +285,10 @@ namespace quda { r2 = blas::xmyNorm(b, r); } - if (x.Precision() != r_sloppy.Precision()) blas::copy(r_sloppy, r); + if (param.precision != param.precision_sloppy) blas::copy(r_sloppy, r); blas::zero(x_sloppy); - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); maxrr = rNorm; maxrx = rNorm; // r0Norm = rNorm; @@ -305,20 +300,17 @@ namespace quda { k++; PrintStats("BiCGstab", k, r2, b2, heavy_quark_res); - logQuda(QUDA_DEBUG_VERBOSE, "BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, r0=%e, t2=%e\n", blas::norm2(x), - blas::norm2(r_sloppy), blas::norm2(v), blas::norm2(p), blas::norm2(r0), blas::norm2(t)); - - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + converged = convergence(r2, heavy_quark_res, stop, stop_hq); if (converged) { // make sure we've truly converged if (!just_updated) { - if (x.Precision() != x_sloppy.Precision()) blas::copy(x, x_sloppy); + if (x.Precision() != x_sloppy[0].Precision()) blas::copy(x, x_sloppy); blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); // Compute r_defl = RHS - A * LHS @@ -326,10 +318,10 @@ namespace quda { r2 = blas::xmyNorm(b, r); } - if (x.Precision() != r_sloppy.Precision()) blas::copy(r_sloppy, r); + if (r[0].Precision() != r_sloppy[0].Precision()) blas::copy(r_sloppy, r); blas::zero(x_sloppy); - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); maxrr = rNorm; 
maxrx = rNorm;
          // r0Norm = rNorm;
@@ -339,19 +331,26 @@ namespace quda {
        }

        // explicitly compute the HQ residual if need be
-       heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(y, r).z) : 0.0;
+       if (use_heavy_quark_res) {
+         auto hq = blas::HeavyQuarkResidualNorm(y, r);
+         for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z);
+       }

        // Update convergence check
-       converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
+       converged = convergence(r2, heavy_quark_res, stop, stop_hq);
      }

      // update p
      if ((!param.pipeline || updateR) && !converged) { // need to update if not pipeline or did a reliable update
-       if (abs(rho * alpha) == 0.0)
-         beta = 0.0;
-       else
-         beta = (rho / rho0) * (alpha / omega);
-       blas::cxpaypbz(r_sloppy, -beta * omega, v, beta, p);
+       vector<Complex> beta_omega(b.size());
+       for (auto i = 0u; i < b.size(); i++) {
+         if (abs(rho[i] * alpha[i]) == 0.0)
+           beta[i] = 0.0;
+         else
+           beta[i] = (rho[i] / rho0[i]) * (alpha[i] / omega[i]);
+         beta_omega[i] = -beta[i] * omega[i];
+       }
+       blas::cxpaypbz(r_sloppy, beta_omega, v, beta, p);
      }
    }
@@ -369,12 +368,14 @@ namespace quda {

    logQuda(QUDA_VERBOSE, "BiCGstab: Reliable updates = %d\n", rUpdate);

-   if (!param.is_preconditioner) { // do not do the below if we this is an inner solver
+   if (!param.is_preconditioner) { // do not do the below if this is an inner solver
      // r2 was freshly computed
-     param.true_res = sqrt(r2 / b2);
-     param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0;
-
-     PrintSummary("BiCGstab", k, r2, b2, stop, param.tol_hq);
+     auto hq = use_heavy_quark_res ? blas::HeavyQuarkResidualNorm(x, r) : vector<double3>(b.size(), {});
+     for (auto i = 0u; i < b.size(); i++) {
+       param.true_res[i] = sqrt(r2[i] / b2[i]);
+       param.true_res_hq[i] = sqrt(hq[i].z);
+     }
+     PrintSummary("BiCGstab", k, r2, b2, stop, stop_hq);
    }

    getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE);
diff --git a/lib/inv_bicgstabl_quda.cpp b/lib/inv_bicgstabl_quda.cpp
index cf5c248297..98176af28a 100644
--- a/lib/inv_bicgstabl_quda.cpp
+++ b/lib/inv_bicgstabl_quda.cpp
@@ -22,9 +22,17 @@
 namespace quda {

+  auto get_i = [](std::vector<std::vector<ColorSpinorField>> &p, int i) {
+    vector_ref<ColorSpinorField> p_i;
+    p_i.reserve(p.size());
+    for (auto &pi : p) p_i.push_back(pi[i]);
+    return p_i;
+  };
+
 #ifndef LEGACY_MR
   // Compute the MR portion of BiCGstab-L
-  void BiCGstabL::computeMR(ColorSpinorField &x_sloppy, bool fixed_iteration)
+  void BiCGstabL::computeMR(ColorSpinorField &x_sloppy, cvector_ref<ColorSpinorField> &u,
+                            cvector_ref<ColorSpinorField> &r, bool fixed_iteration, int src_idx)
   {
     using matrix = Matrix<Complex, Dynamic, Dynamic, RowMajor>;
     using vector = Matrix<Complex, Dynamic, 1>;
@@ -64,12 +72,11 @@
     }

     // Update omega for the next BiCG iteration
-    omega = gamma(n_krylov - 1);
+    omega[src_idx] = gamma(n_krylov - 1);

     std::vector<Complex> gamma_(n_krylov);

     if (!fixed_iteration) {
-
       // update u
       {
         // u = u[0] - \sum_{j=1}^L \gamma_L u_L
@@ -100,39 +107,46 @@
     }
   }

-  void BiCGstabL::computeTau(int, int, int)
+  void BiCGstabL::computeTau(int, int, int, cvector_ref<ColorSpinorField> &, int)
   {
     errorQuda("Legacy MR path in BiCGstab-L called with a non-legacy compile");
   }

-  void BiCGstabL::updateR(int, int, int) { errorQuda("Legacy MR path in BiCGstab-L called with a non-legacy compile"); }
+  void BiCGstabL::updateR(int, int, int, cvector_ref<ColorSpinorField> &, int)
+  {
+    errorQuda("Legacy MR path in BiCGstab-L called with a non-legacy compile");
+  }

-  void BiCGstabL::legacyComputeMR(ColorSpinorField &)
+  void BiCGstabL::legacyComputeMR(ColorSpinorField &, cvector_ref<ColorSpinorField> &, cvector_ref<ColorSpinorField> &,
+                                  int)
   {
     errorQuda("Legacy MR path in
BiCGstab-L called with a non-legacy compile"); } #else - void BiCGstabL::computeMR(ColorSpinorField &, bool) + void BiCGstabL::computeMR(ColorSpinorField &, cvector_ref &, cvector_ref &, bool, + int) { errorQuda("Non-legacy MR path in BiCGstab-L called with a legacy compile"); } // Utility functions for Gram-Schmidt. Based on GCR functions. // Big change is we need to go from 1 to n_krylov, not 0 to n_krylov-1. - void BiCGstabL::computeTau(int begin, int size, int j) + void BiCGstabL::computeTau(int begin, int size, int j, cvector_ref &r, int src_idx) { std::vector Tau(size); blas::block::cDotProduct(Tau, {r.begin() + begin, r.begin() + begin + size}, r[j]); // vectorized dot product - for (int k = 0; k < size; k++) { tau[(begin + k) * (n_krylov + 1) + j] = Tau[k] / sigma[begin + k]; } + for (int k = 0; k < size; k++) { + tau[src_idx][(begin + k) * (n_krylov + 1) + j] = Tau[k] / sigma[src_idx][begin + k]; + } } - void BiCGstabL::updateR(int begin, int size, int j) + void BiCGstabL::updateR(int begin, int size, int j, cvector_ref &r, int src_idx) { std::vector tau_(size); - for (int i = 0; i < size; i++) { tau_[i] = -tau[(i + begin) * (n_krylov + 1) + j]; } + for (int i = 0; i < size; i++) { tau_[i] = -tau[src_idx][(i + begin) * (n_krylov + 1) + j]; } auto r_ = {r.begin() + begin, r.begin() + begin + size}; auto rj = {r.begin() + j, r.begin() + j + 1}; @@ -144,7 +158,8 @@ namespace quda { Legacy routine for the original pipelined Gram-Schmit See "The MR part" in https://etna.math.kent.edu/vol.1.1993/pp11-32.dir/pp11-32.pdf */ - void BiCGstabL::legacyComputeMR(ColorSpinorField &x_sloppy) + void BiCGstabL::legacyComputeMR(ColorSpinorField &x_sloppy, cvector_ref &u, + cvector_ref &r, int src_idx) { // MR part. Really just modified Gram-Schmidt. // The algorithm uses the byproducts of the Gram-Schmidt to update x @@ -157,8 +172,8 @@ namespace quda { case 0: // no kernel fusion for (int i = 1; i < j; i++) // 5 (j-2) memory transactions here. Start at 1 b/c bicgstabl convention. { - tau[i * (n_krylov + 1) + j] = blas::cDotProduct(r[i], r[j]) / sigma[i]; - blas::caxpy(-tau[i * (n_krylov + 1) + j], r[i], r[j]); + tau[src_idx][i * (n_krylov + 1) + j] = blas::cDotProduct(r[i], r[j]) / sigma[src_idx][i]; + blas::caxpy(-tau[src_idx][i * (n_krylov + 1) + j], r[i], r[j]); } break; case 1: // basic kernel fusion @@ -166,13 +181,13 @@ namespace quda { { break; } - tau[1 * (n_krylov + 1) + j] = blas::cDotProduct(r[1], r[j]) / sigma[1]; + tau[src_idx][1 * (n_krylov + 1) + j] = blas::cDotProduct(r[1], r[j]) / sigma[src_idx][1]; for (int i = 1; i < j - 1; i++) // 4 (j-2) memory transactions here. start at 1. { - tau[(i + 1) * (n_krylov + 1) + j] - = blas::caxpyDotzy(-tau[i * (n_krylov + 1) + j], r[i], r[j], r[i + 1]) / sigma[i + 1]; + auto dot = blas::caxpyDotzy(-tau[src_idx][i * (n_krylov + 1) + j], r[i], r[j], r[i + 1]); + tau[src_idx][(i + 1) * (n_krylov + 1) + j] = dot / sigma[src_idx][i + 1]; } - blas::caxpy(-tau[(j - 1) * (n_krylov + 1) + j], r[j - 1], r[j]); + blas::caxpy(-tau[src_idx][(j - 1) * (n_krylov + 1) + j], r[j - 1], r[j]); break; default: { const int N = pipeline; @@ -183,15 +198,15 @@ namespace quda { // (j-1)/N updates of length N, at 1,1+N,1+2*N,... int step; for (step = 0; step < (j - 1) / N; step++) { - computeTau(1 + step * N, N, j); - updateR(1 + step * N, N, j); + computeTau(1 + step * N, N, j, r, src_idx); + updateR(1 + step * N, N, j, r, src_idx); } if ((j - 1) % N != 0) // need to update the remainder { // 1 update of length (j-1)%N. 
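    // Note on the pipelined path above: with pipeline width N, the j-1 projections
    // of r[j] against r[1..j) are issued as (j-1)/N full blocks plus one remainder
    // block. Within a block [begin, begin+size):
    //   computeTau: tau_{i,j} = <r_i, r_j> / sigma_i   -- one batched reduction
    //   updateR:    r_j      -= tau_{i,j} * r_i        -- one batched caxpy
    // All of a block's dot products are taken before any of its updates are
    // applied (classical Gram-Schmidt inside a block, modified Gram-Schmidt
    // across blocks), which is what makes the kernel fusion legal.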
-        computeTau(1 + step * N, (j - 1) % N, j);
-        updateR(1 + step * N, (j - 1) % N, j);
+        computeTau(1 + step * N, (j - 1) % N, j, r, src_idx);
+        updateR(1 + step * N, (j - 1) % N, j, r, src_idx);
       }
     } break;
     }

@@ -199,27 +214,30 @@
      // sigma_j = r_j^2, gamma'_j = <r_j, r_0>/sigma_j
      // rjr.x = Re(<r_j, r_0>), rjr.z = <r_j, r_j>
-     double3 rjr = blas::cDotProductNormA(r[j], r[0]);
-     sigma[j] = rjr.z;
-     gamma_prime[j] = Complex(rjr.x, rjr.y) / sigma[j];
+     auto rjr = blas::cDotProductNormA(r[j], r[0]);
+     sigma[src_idx][j] = rjr.z;
+     gamma_prime[src_idx][j] = Complex(rjr.x, rjr.y) / sigma[src_idx][j];
    }

    // gamma[n_krylov] = gamma'[n_krylov], omega = gamma[n_krylov]
-   gamma[n_krylov] = gamma_prime[n_krylov];
-   omega = gamma[n_krylov];
+   gamma[src_idx][n_krylov] = gamma_prime[src_idx][n_krylov];
+   omega[src_idx] = gamma[src_idx][n_krylov];

    // gamma = T^(-1) gamma_prime. It's in the paper, I promise.
    for (int j = n_krylov - 1; j > 0; j--) {
      // Internal def: gamma[j] = gamma'_j - \sum_{i = j+1 to n_krylov} tau_ji gamma_i
-     gamma[j] = gamma_prime[j];
-     for (int i = j + 1; i <= n_krylov; i++) { gamma[j] = gamma[j] - tau[j * (n_krylov + 1) + i] * gamma[i]; }
+     gamma[src_idx][j] = gamma_prime[src_idx][j];
+     for (int i = j + 1; i <= n_krylov; i++) {
+       gamma[src_idx][j] = gamma[src_idx][j] - tau[src_idx][j * (n_krylov + 1) + i] * gamma[src_idx][i];
+     }
    }

    // gamma'' = T S gamma. Check paper for defn of S.
    for (int j = 1; j < n_krylov; j++) {
-     gamma_prime_prime[j] = gamma[j + 1];
+     gamma_prime_prime[src_idx][j] = gamma[src_idx][j + 1];
      for (int i = j + 1; i < n_krylov; i++) {
-       gamma_prime_prime[j] = gamma_prime_prime[j] + tau[j * (n_krylov + 1) + i] * gamma[i + 1];
+       gamma_prime_prime[src_idx][j]
+         = gamma_prime_prime[src_idx][j] + tau[src_idx][j * (n_krylov + 1) + i] * gamma[src_idx][i + 1];
      }
    }

@@ -229,7 +247,7 @@
    // Update U
    {
      std::vector<Complex> gamma_(n_krylov);
-     for (int i = 0; i < n_krylov; i++) { gamma_[i] = -gamma[i + 1]; }
+     for (int i = 0; i < n_krylov; i++) { gamma_[i] = -gamma[src_idx][i + 1]; }
      blas::block::caxpy(gamma_, {u.begin() + 1, u.end()}, u[0]);
    }

@@ -241,15 +259,15 @@
    // the full precision, this can be a killer.
    std::vector<Complex> gamma_prime_prime_(n_krylov + 1);
    std::vector<Complex> gamma_prime_(n_krylov + 1);
-   gamma_prime_prime_[0] = gamma[1];
+   gamma_prime_prime_[0] = gamma[src_idx][1];
    gamma_prime_prime_[n_krylov] = 0.0; // x never gets updated with r[n_krylov]
    gamma_prime_[0] = 0.0;              // r[0] never gets updated with r[0]... obvs.
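    // Restating the two coefficient solves above in the notation of the paper
    // (L = n_krylov, tau strictly upper triangular, sigma_j = ||r_j||^2,
    // gamma'_j = <r_j, r_0> / sigma_j, omega = gamma_L):
    //   gamma_j   = gamma'_j    - sum_{i=j+1}^{L}   tau_{j,i} * gamma_i      for j = L-1, ..., 1
    //   gamma''_j = gamma_{j+1} + sum_{i=j+1}^{L-1} tau_{j,i} * gamma_{i+1}  for j = 1, ..., L-1
    // The gamma''/gamma' arrays being filled here then drive the fused x and r[0] update below.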
-    gamma_prime_[n_krylov] = -gamma_prime[n_krylov];
+    gamma_prime_[n_krylov] = -gamma_prime[src_idx][n_krylov];
     for (int i = 1; i < n_krylov; i++) {
-      gamma_prime_prime_[i] = gamma_prime_prime[i];
-      gamma_prime_[i] = -gamma_prime[i];
+      gamma_prime_prime_[i] = gamma_prime_prime[src_idx][i];
+      gamma_prime_[i] = -gamma_prime[src_idx][i];
     }
-    blas::caxpyBxpz(gamma_prime_prime_, r, x_sloppy, gamma_prime_, r[0]);
+    blas::block::caxpyBxpz(gamma_prime_prime_, r, x_sloppy, gamma_prime_, r[0]);
   }
 }

@@ -273,12 +291,12 @@
 class BiCGstabLUpdate : public Worker
 {
-  ColorSpinorField &x;
-  std::vector<ColorSpinorField> &r;
-  std::vector<ColorSpinorField> &u;
+  std::vector<ColorSpinorField> &x;
+  std::vector<std::vector<ColorSpinorField>> &r;
+  std::vector<std::vector<ColorSpinorField>> &u;

-  Complex &alpha;
-  Complex &beta;
+  std::vector<Complex> &alpha;
+  std::vector<Complex> &beta;

   BiCGstabLUpdateType update_type;
@@ -296,13 +314,13 @@
   int n_update;

 public:
-  BiCGstabLUpdate(ColorSpinorField &x, std::vector<ColorSpinorField> &r, std::vector<ColorSpinorField> &u,
-                  Complex &alpha, Complex &beta, BiCGstabLUpdateType update_type, int j_max, int n_update) :
+  BiCGstabLUpdate(std::vector<ColorSpinorField> &x, std::vector<std::vector<ColorSpinorField>> &r,
+                  std::vector<std::vector<ColorSpinorField>> &u, std::vector<Complex> &alpha,
+                  std::vector<Complex> &beta, BiCGstabLUpdateType update_type, int j_max, int n_update) :
     x(x), r(r), u(u), alpha(alpha), beta(beta), update_type(update_type), j_max(j_max), n_update(n_update)
   {
   }
-  virtual ~BiCGstabLUpdate() { }

   void update_j_max(int new_j_max) { j_max = new_j_max; }
   void update_update_type(BiCGstabLUpdateType new_update_type) { update_type = new_update_type; }
@@ -317,7 +335,7 @@
     if (update_type == BICGSTABL_UPDATE_U) {
       for (int i = (count * j_max) / n_update; i < ((count + 1) * j_max) / n_update && i < j_max; i++) {
-        blas::caxpby(1.0, r[i], -beta, u[i]);
+        for (auto j = 0u; j < beta.size(); j++) blas::caxpby(1.0, r[i][j], -beta[j], u[i][j]);
       }
     } else // (update_type == BICGSTABL_UPDATE_R)
     {
       if (count == 0) { blas::caxpy(alpha, u[0], x); }
       if (j_max > 0) {
-        for (int i= (count*j_max)/n_update; i<((count+1)*j_max)/n_update && i<j_max; i++)
-        {
-          blas::caxpy(-alpha, u[i+1], r[i]);
-        }
+        for (int i = (count * j_max) / n_update; i < ((count + 1) * j_max) / n_update && i < j_max; i++) {
+          for (auto j = 0u; j < alpha.size(); j++) blas::caxpy(-alpha[j], u[i + 1][j], r[i][j]);
+        }
       }
     }

  void BiCGstabL::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b)
  {
    Solver::create(x, b);
-    if (!init) {
+    if (!init || y.size() != b.size()) {
      getProfile().TPSTART(QUDA_PROFILE_INIT);

      // Initialize fields.
-      ColorSpinorParam csParam(x);
+      ColorSpinorParam csParam(x[0]);
      csParam.create = QUDA_ZERO_FIELD_CREATE;

      // Full precision variables.
-      r_full = ColorSpinorField(csParam);
+      resize(r_full, b.size(), csParam);

      // Create temporary.
-      y = ColorSpinorField(csParam);
+      resize(y, b.size(), csParam);

      // Sloppy precision variables.
      csParam.setPrecision(param.precision_sloppy);

      // Sloppy solution.
      if (!mixed() || !param.use_sloppy_partial_accumulator) {
-        x_sloppy = x.create_alias(); // x_sloppy and x point to the same vector in memory.
+        create_alias(x_sloppy, x); // x_sloppy and x point to the same vector in memory.
      } else {
-        x_sloppy = ColorSpinorField(csParam);
+        resize(x_sloppy, b.size(), csParam);
      }

      // Shadow residual.
      if (!mixed() && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) {
-        r0 = const_cast<ColorSpinorField &>(b).create_alias();
+        create_alias(r0, b);
      } else {
-        r0 = ColorSpinorField(csParam);
+        resize(r0, b.size(), csParam);
      }

      // Temporary
-      temp = ColorSpinorField(csParam);
+      resize(temp, b.size(), csParam);

      // Residual (+ extra residuals for BiCG steps), Search directions.
      // Remark: search directions are sloppy in GCR. I wonder if we can
      // get away with that here.
+      r.resize(n_krylov + 1);
+      u.resize(n_krylov + 1);
      for (int i = 0; i <= n_krylov; i++) {
-        r[i] = (i > 0 || mixed()) ?
ColorSpinorField(csParam) : r_full.create_alias(); - u[i] = ColorSpinorField(csParam); + if (i > 0 || mixed()) + resize(r[i], b.size(), csParam); + else + create_alias(r[i], r_full); + resize(u[i], b.size(), csParam); + } + + alpha.resize(b.size(), 0.0); + beta.resize(b.size()); + omega.resize(b.size(), 1.0); + rho0.resize(b.size(), 1.0); + rho1.resize(b.size()); + + gamma.resize(b.size()); + gamma_prime.resize(b.size()); + gamma_prime_prime.resize(b.size()); + sigma.resize(b.size()); + tau.resize(b.size()); + for (auto i = 0u; i < b.size(); i++) { + gamma[i].resize(n_krylov + 1); + gamma_prime[i].resize(n_krylov + 1); + gamma_prime_prime[i].resize(n_krylov + 1); + sigma[i].resize(n_krylov + 1); + tau[i].resize((n_krylov + 1) * (n_krylov + 1)); } getProfile().TPSTOP(QUDA_PROFILE_INIT); @@ -434,7 +465,7 @@ namespace quda { } } - void BiCGstabL::operator()(ColorSpinorField &x, ColorSpinorField &b) + void BiCGstabL::operator()(cvector_ref &x, cvector_ref &b) { // BiCGstab-l is based on the algorithm outlined in // BICGSTAB(L) FOR LINEAR EQUATIONS INVOLVING UNSYMMETRIC MATRICES WITH COMPLEX SPECTRUM @@ -449,16 +480,16 @@ namespace quda { // compute b2, but only if we need to bool fixed_iteration = param.sloppy_converge && n_krylov == param.maxiter && !param.compute_true_res; - double b2 = !fixed_iteration ? blas::norm2(b) : 1.0; // norm sq of source. - double r2; // norm sq of residual + auto b2 = !fixed_iteration ? blas::norm2(b) : vector(b.size(), 1.0); // norm sq of source. + vector r2(b.size()); // norm sq of residual if (param.deflate) { // Construct the eigensolver and deflation space if requested. if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { @@ -514,19 +545,9 @@ namespace quda { } // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); - return; - } else if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - b2 = r2; - } else { - errorQuda("Null vector computing requires non-zero guess!"); - } + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); + return; } // Set shadow residual depending if the source vector is directly usable @@ -547,15 +568,16 @@ namespace quda { // Initialize values. for (int i = 1; i <= n_krylov; i++) { blas::zero(r[i]); } - rho0 = 1.0; - alpha = 0.0; - omega = 1.0; - - double stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : 0.0; // stopping condition of solver. + auto stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : + vector(b.size(), 0.0); // stopping condition of solver. const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; - double heavy_quark_res = use_heavy_quark_res ? 
sqrt(blas::HeavyQuarkResidualNorm(x, r_full).z) : 0.0;
+    vector<double> heavy_quark_res(b.size(), 0.0);
+    if (use_heavy_quark_res) {
+      auto hq = blas::HeavyQuarkResidualNorm(x, r_full);
+      for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z);
+    }
     const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual

     //bool l2_converge = false;
@@ -576,27 +598,31 @@
     // Various variables related to reliable updates.
     int rUpdate = 0; // count reliable updates.
     double delta = param.delta; // delta for reliable updates.
-    double rNorm = sqrt(r2); // The current residual norm.
+    double rNorm = sqrt(r2[0]); // The current residual norm.
     double maxrr = rNorm; // The maximum residual norm since the last reliable update.
     double maxrx = rNorm; // The same. Would be different if we did 'x' reliable updates.

     PrintStats(solver_name.c_str(), total_iter, r2, b2, heavy_quark_res);

-    while (!convergence(r2, 0.0, stop, 0.0) && total_iter < param.maxiter) {
+    while (!convergenceL2(r2, stop) && total_iter < param.maxiter) {
       // rho0 = -omega*rho0;
-      rho0 *= -omega;
+      for (auto i = 0u; i < b.size(); i++) rho0[i] *= -omega[i];

       // BiCG part of calculation.
       for (int j = 0; j < n_krylov; j++) {
         // rho1 = <r0, r[j]>, beta = alpha*rho1/rho0, rho0 = rho1;
         // Can fuse into updateXRend.
         rho1 = blas::cDotProduct(r0, r[j]);
-        beta = alpha*rho1/rho0;
-        rho0 = rho1;
+        for (auto i = 0u; i < b.size(); i++) {
+          beta[i] = alpha[i] * rho1[i] / rho0[i];
+          rho0[i] = rho1[i];
+        }

         // for i = 0 .. j, u[i] = r[i] - beta*u[i]
         // All but i = j is hidden in Dslash auxiliary work (overlapping comms and compute).
-        blas::caxpby(1.0, r[j], -beta, u[j]);
+        std::vector<Complex> minus_beta(beta.size());
+        for (auto i = 0u; i < beta.size(); i++) minus_beta[i] = -beta[i];
+        blas::caxpby(1.0, r[j], minus_beta, u[j]);

         if (j > 0) {
           dslash::aux_worker = &bicgstabl_update;
@@ -611,11 +637,14 @@
         // alpha = rho0 / <r0, u[j+1]>
         // The machinery isn't there yet, but this could be fused with the matSloppy above.
-        alpha = rho0/blas::cDotProduct(r0, u[j+1]);
+        auto r0Tu = blas::cDotProduct(r0, u[j + 1]);
+        for (auto i = 0u; i < b.size(); i++) alpha[i] = rho0[i] / r0Tu[i];

         // for i = 0 .. j, r[i] = r[i] - alpha u[i+1]
         // All but i = j is hidden in Dslash auxiliary work (overlapping comms and compute).
-        blas::caxpy(-alpha, u[j+1], r[j]);
+        std::vector<Complex> minus_alpha(alpha.size());
+        for (auto i = 0u; i < alpha.size(); i++) minus_alpha[i] = -alpha[i];
+        blas::caxpy(minus_alpha, u[j + 1], r[j]);

         // We can always at least update x.
         dslash::aux_worker = &bicgstabl_update;
         bicgstabl_update.update_j_max(j);
@@ -630,11 +659,11 @@
 #ifndef LEGACY_MR
       // Perform the MR portion of BiCGstab-L
       // if we're doing a fixed number of iterations, we only need to update x
-      computeMR(x_sloppy, fixed_iteration);
+      for (auto i = 0u; i < b.size(); i++) computeMR(x_sloppy[i], get_i(u, i), get_i(r, i), fixed_iteration, i);
 #else
       // Legacy version matching the BiCGstab-L paper which performs
       // an explicit Gram-Schmidt for the MR portion
-      legacyComputeMR(x_sloppy);
+      for (auto i = 0u; i < b.size(); i++) legacyComputeMR(x_sloppy[i], get_i(u, i), get_i(r, i), i);
 #endif

       if (!fixed_iteration) {
@@ -644,12 +673,14 @@
       // Check the heavy quark residual if we need to.
if (use_heavy_quark_res && total_iter % heavy_quark_check == 0) { - if (&x != &x_sloppy) { + if (x.Precision() != x_sloppy[0].Precision()) { blas::copy(temp, y); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x_sloppy, temp, r[0]).z); + auto hq = blas::xpyHeavyQuarkResidualNorm(x_sloppy, temp, r[0]); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } else { blas::copy(r_full, r[0]); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r_full).z); + auto hq = blas::xpyHeavyQuarkResidualNorm(x, y, r_full); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } } } @@ -660,7 +691,7 @@ namespace quda { // updated (depending on if you're using pipelining or not). In BiCGstab-L, there's only // one place (for now) to get the updated residual, so we just do away with 'updateR'. // Further remark: "reliable" updates rNorm, maxrr, maxrx!! - if (total_iter >= param.maxiter || r2 < stop || reliable(rNorm, maxrx, maxrr, r2, delta)) { + if (total_iter >= param.maxiter || r2 < stop || reliable(rNorm, maxrx, maxrr, r2[0], delta)) { if ((r2 < stop || total_iter >= param.maxiter) && param.sloppy_converge) break; if (mixed()) { blas::copy(x, x_sloppy); } @@ -681,7 +712,7 @@ namespace quda { blas::zero(x_sloppy); // Update rNorm, maxrr, maxrx. - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); maxrr = rNorm; maxrx = rNorm; @@ -713,14 +744,17 @@ namespace quda { // !param.is_preconditioner comes from bicgstab, param.compute_true_res came from gcr. if (!param.is_preconditioner && param.compute_true_res) { // do not do the below if this is an inner solver. mat(r_full, x); - double true_res = blas::xmyNorm(b, r_full); - param.true_res = sqrt(true_res / b2); - param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x, r[0]).z) : 0.0; + auto true_res = blas::xmyNorm(b, r_full); + auto hq = use_heavy_quark_res ? blas::HeavyQuarkResidualNorm(x, r[0]) : vector(b.size(), {}); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_res[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); - PrintSummary(solver_name.c_str(), total_iter, r2, b2, stop, param.tol_hq); + PrintSummary(solver_name.c_str(), total_iter, r2, b2, stop); } } // namespace quda diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index 02b05c91bf..86b6429532 100644 --- a/lib/inv_ca_cg.cpp +++ b/lib/inv_ca_cg.cpp @@ -27,174 +27,27 @@ namespace quda if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_FREE); } - CACGNE::CACGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam ¶m) : - CACG(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param), - mmdag(mat.Expose()), - mmdagSloppy(matSloppy.Expose()), - mmdagPrecon(matPrecon.Expose()), - mmdagEig(matEig.Expose()) - { - } - - void CACGNE::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(x); - csParam.create = QUDA_NULL_FIELD_CREATE; - xp = ColorSpinorField(csParam); - csParam.create = QUDA_ZERO_FIELD_CREATE; - yp = ColorSpinorField(csParam); - init = true; - } - } - - ColorSpinorField &CACGNE::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - // CG residual will match the CGNE residual (FIXME: but only with zero initial guess?) - return param.use_init_guess ? 
- return param.use_init_guess ? xp : CACG::get_residual(); - } - - // CACGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CACGNE::operator()(ColorSpinorField &x, ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = param.compute_true_res ? blas::norm2(b) : 0.0; - - if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - // compute initial residual - mmdag.Expose()->M(xp, x); - if (param.compute_true_res && b2 == 0.0) - b2 = blas::xmyNorm(b, xp); - else - blas::xpay(b, -1.0, xp); - - // compute solution to residual equation - CACG::operator()(yp, xp); - - mmdag.Expose()->Mdag(xp, yp); - - // compute full solution - blas::xpy(xp, x); - } else { - CACG::operator()(yp, b); - mmdag.Expose()->Mdag(x, yp); - } - - if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { - // compute the true residual - mmdag.Expose()->M(xp, x); - blas::xpay(b, -1.0, xp); // xp now holds the residual - - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, xp); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(xp); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CA-CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - - CACGNR::CACGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param) : - CACG(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param), - mdagm(mat.Expose()), - mdagmSloppy(matSloppy.Expose()), - mdagmPrecon(matPrecon.Expose()), - mdagmEig(matEig.Expose()) - { - } - - void CACGNR::create(ColorSpinorField &x, const ColorSpinorField &b) + void CACG::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b) { Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_ZERO_FIELD_CREATE; - br = ColorSpinorField(csParam); - init = true; - } - } - - ColorSpinorField &CACGNR::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - return br; - } - - // CACGNR: Mdag M x = Mdag b is solved.
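The vectorized CACG::create that continues below sizes one coefficient workspace per right-hand side, with the Krylov-space storage nested inside. A sketch of that two-level sizing, with illustrative names rather than the solver's actual members:

#include <vector>

// Outer index runs over right-hand sides, inner storage over the Krylov space.
void resize_coefficients(size_t n_rhs, int n_krylov,
                         std::vector<std::vector<double>> &Q_AQandg,
                         std::vector<std::vector<double>> &alpha)
{
  Q_AQandg.resize(n_rhs);
  alpha.resize(n_rhs);
  for (size_t i = 0; i < n_rhs; i++) {
    Q_AQandg[i].resize(n_krylov * (n_krylov + 1)); // packs Q^dag AQ alongside g = Q^dag r
    alpha[i].resize(n_krylov);
  }
}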
- void CACGNR::operator()(ColorSpinorField &x, ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = 0.0; - if (param.compute_true_res) { - b2 = blas::norm2(b); - if (b2 == 0.0) { // compute initial residual vector - mdagm.Expose()->M(br, x); - b2 = blas::norm2(br); - } - } - - mdagm.Expose()->Mdag(br, b); - CACG::operator()(x, br); - - if (param.compute_true_res || param.return_residual) { - // compute the true residual - mdagm.Expose()->M(br, x); - blas::xpay(b, -1.0, br); // br now holds the residual - - if (param.compute_true_res) { - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, br); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(br); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CA-CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - } - - void CACG::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); - Q_AQandg.resize(param.Nkrylov * (param.Nkrylov + 1)); - Q_AS.resize(param.Nkrylov * param.Nkrylov); - alpha.resize(param.Nkrylov); - beta.resize(param.Nkrylov * param.Nkrylov); + Q_AQandg.resize(b.size()); + Q_AS.resize(b.size()); + alpha.resize(b.size()); + beta.resize(b.size()); + for (auto i = 0u; i < b.size(); i++) { + Q_AQandg[i].resize(param.Nkrylov * (param.Nkrylov + 1)); + Q_AS[i].resize(param.Nkrylov * param.Nkrylov); + alpha[i].resize(param.Nkrylov); + beta[i].resize(param.Nkrylov * param.Nkrylov); + } - ColorSpinorParam csParam(b); + ColorSpinorParam csParam(b[0]); csParam.create = QUDA_NULL_FIELD_CREATE; - if (mixed()) r = ColorSpinorField(csParam); + if (mixed()) resize(r, b.size(), csParam); // now allocate sloppy fields csParam.setPrecision(param.precision_sloppy); @@ -205,15 +58,18 @@ namespace quda Qtmp.resize(param.Nkrylov); // only used as an intermediate for pointer swaps S.resize(param.Nkrylov); for (int i = 0; i < param.Nkrylov; i++) { - AS[i] = ColorSpinorField(csParam); - Q[i] = ColorSpinorField(csParam); - AQ[i] = ColorSpinorField(csParam); - Qtmp[i] = ColorSpinorField(csParam); + resize(AS[i], b.size(), csParam); + resize(Q[i], b.size(), csParam); + resize(AQ[i], b.size(), csParam); + resize(Qtmp[i], b.size(), csParam); // in the power basis we can alias AS[k] to S[k+1] - S[i] = (basis == QUDA_POWER_BASIS && i > 0) ? 
AS[i - 1] : ColorSpinorField(csParam); + if (basis == QUDA_POWER_BASIS && i > 0) + create_alias(S[i], AS[i - 1]); + else + resize(S[i], b.size(), csParam); } - if (!mixed()) r = S[0].create_alias(csParam); + if (!mixed()) create_alias(r, S[0]); if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); @@ -241,7 +97,7 @@ namespace quda // psi = svd.solve(phi); } - void CACG::compute_alpha() + void CACG::compute_alpha(int b) { if (!param.is_preconditioner) { getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); @@ -251,31 +107,30 @@ namespace quda const int N = Q.size(); switch (N) { #if 0 // since CA-CG is not used anywhere at the moment, no point paying for this compilation cost - case 1: compute_alpha_N<1>(Q_AQandg, alpha); break; - case 2: compute_alpha_N<2>(Q_AQandg, alpha); break; - case 3: compute_alpha_N<3>(Q_AQandg, alpha); break; - case 4: compute_alpha_N<4>(Q_AQandg, alpha); break; - case 5: compute_alpha_N<5>(Q_AQandg, alpha); break; - case 6: compute_alpha_N<6>(Q_AQandg, alpha); break; - case 7: compute_alpha_N<7>(Q_AQandg, alpha); break; - case 8: compute_alpha_N<8>(Q_AQandg, alpha); break; - case 9: compute_alpha_N<9>(Q_AQandg, alpha); break; - case 10: compute_alpha_N<10>(Q_AQandg, alpha); break; - case 11: compute_alpha_N<11>(Q_AQandg, alpha); break; - case 12: compute_alpha_N<12>(Q_AQandg, alpha); break; + case 1: compute_alpha_N<1>(Q_AQandg[b], alpha[b]); break; + case 2: compute_alpha_N<2>(Q_AQandg[b], alpha[b]); break; + case 3: compute_alpha_N<3>(Q_AQandg[b], alpha[b]); break; + case 4: compute_alpha_N<4>(Q_AQandg[b], alpha[b]); break; + case 5: compute_alpha_N<5>(Q_AQandg[b], alpha[b]); break; + case 6: compute_alpha_N<6>(Q_AQandg[b], alpha[b]); break; + case 7: compute_alpha_N<7>(Q_AQandg[b], alpha[b]); break; + case 8: compute_alpha_N<8>(Q_AQandg[b], alpha[b]); break; + case 9: compute_alpha_N<9>(Q_AQandg[b], alpha[b]); break; + case 10: compute_alpha_N<10>(Q_AQandg[b], alpha[b]); break; + case 11: compute_alpha_N<11>(Q_AQandg[b], alpha[b]); break; + case 12: compute_alpha_N<12>(Q_AQandg[b], alpha[b]); break; #endif default: // failsafe typedef Matrix matrix; typedef Matrix vector; - const int N = Q.size(); matrix matQ_AQ(N, N); vector vecg(N); for (int i = 0; i < N; i++) { - vecg(i) = Q_AQandg[i * (N + 1) + N]; - for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[i * (N + 1) + j]; } + vecg(i) = Q_AQandg[b][i * (N + 1) + N]; + for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[b][i * (N + 1) + j]; } } - Map vecalpha(alpha.data(), N); + Map vecalpha(alpha[b].data(), N); vecalpha = matQ_AQ.fullPivLu().solve(vecg); @@ -309,7 +164,7 @@ namespace quda // psi = svd.solve(phi); } - void CACG::compute_beta() + void CACG::compute_beta(int b) { if (!param.is_preconditioner) { getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); @@ -319,28 +174,27 @@ namespace quda const int N = Q.size(); switch (N) { #if 0 // since CA-CG is not used anywhere at the moment, no point paying for this compilation cost - case 1: compute_beta_N<1>(Q_AQandg, Q_AS, beta); break; - case 2: compute_beta_N<2>(Q_AQandg, Q_AS, beta); break; - case 3: compute_beta_N<3>(Q_AQandg, Q_AS, beta); break; - case 4: compute_beta_N<4>(Q_AQandg, Q_AS, beta); break; - case 5: compute_beta_N<5>(Q_AQandg, Q_AS, beta); break; - case 6: compute_beta_N<6>(Q_AQandg, Q_AS, beta); break; - case 7: compute_beta_N<7>(Q_AQandg, Q_AS, beta); break; - case 8: compute_beta_N<8>(Q_AQandg, Q_AS, beta); break; - case 9: compute_beta_N<9>(Q_AQandg, Q_AS, beta); break; - case 10: compute_beta_N<10>(Q_AQandg, Q_AS, beta); break; - case 
11: compute_beta_N<11>(Q_AQandg, Q_AS, beta); break; - case 12: compute_beta_N<12>(Q_AQandg, Q_AS, beta); break; + case 1: compute_beta_N<1>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 2: compute_beta_N<2>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 3: compute_beta_N<3>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 4: compute_beta_N<4>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 5: compute_beta_N<5>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 6: compute_beta_N<6>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 7: compute_beta_N<7>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 8: compute_beta_N<8>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 9: compute_beta_N<9>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 10: compute_beta_N<10>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 11: compute_beta_N<11>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 12: compute_beta_N<12>(Q_AQandg[b], Q_AS[b], beta[b]); break; #endif default: // failsafe typedef Matrix matrix; - const int N = Q.size(); matrix matQ_AQ(N, N); for (int i = 0; i < N; i++) { - for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[i * (N + 1) + j]; } + for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[b][i * (N + 1) + j]; } } - Map matQ_AS(Q_AS.data(), N, N), matbeta(beta.data(), N, N); + Map matQ_AS(Q_AS[b].data(), N, N), matbeta(beta[b].data(), N, N); matQ_AQ = -matQ_AQ; matbeta = matQ_AQ.fullPivLu().solve(matQ_AS); @@ -373,7 +227,7 @@ namespace quda return updateR; } - ColorSpinorField &CACG::get_residual() + cvector_ref CACG::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -397,7 +251,7 @@ namespace quda 2. Steepest descent minmization of the residual in this basis 3. Update solution and residual vectors */ - void CACG::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CACG::operator()(cvector_ref &x, cvector_ref &b) { if (param.is_preconditioner) commGlobalReductionPush(param.global_reduction); @@ -414,12 +268,12 @@ namespace quda // compute b2, but only if we need to bool fixed_iteration = param.sloppy_converge && n_krylov == param.maxiter && !param.compute_true_res; - double b2 = !fixed_iteration ? blas::norm2(b) : 1.0; - double r2 = 0.0; // if zero source then we will exit immediately doing no work + auto b2 = !fixed_iteration ? blas::norm2(b) : vector(b.size(), 1.0); + vector r2(b.size(), 0.0); // if zero source then we will exit immediately doing no work if (param.deflate) { // Construct the eigensolver and deflation space. - constructDeflationSpace(b, matEig); + constructDeflationSpace(b[0], matEig); if (deflate_compute) { // compute the deflation space. 
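For reference, the failsafe branches of compute_alpha and compute_beta above reduce to a small dense solve per right-hand side. A standalone sketch of the alpha case, assuming the same row-wise packing of Q^dag AQ with g = Q^dag r as a trailing column:

#include <Eigen/Dense>
#include <vector>

// Solve (Q^dag AQ) alpha = g with a full-pivot LU, as the failsafe path does.
std::vector<double> solve_alpha(const std::vector<double> &Q_AQandg, int N)
{
  Eigen::MatrixXd matQ_AQ(N, N);
  Eigen::VectorXd vecg(N);
  for (int i = 0; i < N; i++) {
    vecg(i) = Q_AQandg[i * (N + 1) + N];
    for (int j = 0; j < N; j++) matQ_AQ(i, j) = Q_AQandg[i * (N + 1) + j];
  }
  Eigen::VectorXd alpha = matQ_AQ.fullPivLu().solve(vecg);
  return {alpha.data(), alpha.data() + N};
}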
if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); @@ -449,6 +303,12 @@ namespace quda blas::zero(x); } + // Check to see that we're not trying to invert on a zero-field source + if (is_zero_src(x, b, b2)) { + if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); + return; + } + if (param.deflate && param.maxiter > 1) { // Deflate and add solution to accumulator eig_solve->deflate(x, r, evecs, evals, true); @@ -474,7 +334,7 @@ namespace quda // Perform 100 power iterations, normalizing every 10 mat-vecs, using r as an initial seed // and Q[0]/AQ[0] as temporaries for the power iterations - lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r, Q[0], AQ[0], 100, 10); + lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r[0], Q[0][0], AQ[0][0], 100, 10); logQuda(QUDA_SUMMARIZE, "CA-CG Approximate lambda max = 1.1 x %e\n", lambda_max / 1.1); lambda_init = true; @@ -485,20 +345,8 @@ namespace quda } } - // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - return; - } else { - b2 = r2; - } - } - - double stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : 0.0; // stopping condition of solver + auto stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : + vector(b2.size(), 0.0); // stopping condition of solver const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; @@ -508,8 +356,11 @@ namespace quda const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance const int maxResIncreaseTotal = param.max_res_increase_total; - double heavy_quark_res = 0.0; // heavy quark residual - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + vector heavy_quark_res(b.size(), 0.0); // heavy quark residual + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } int resIncrease = 0; int resIncreaseTotal = 0; @@ -519,13 +370,13 @@ namespace quda getProfile().TPSTART(QUDA_PROFILE_COMPUTE); } int total_iter = 0; - double r2_old = r2; + auto r2_old = r2; bool l2_converge = false; // Various variables related to reliable updates. int rUpdate = 0; // count reliable updates. double delta = param.delta; // delta for reliable updates. - double rNorm = sqrt(r2); // The current residual norm. + double rNorm = sqrt(r2[0]); // The current residual norm. double maxrr = rNorm; // The maximum residual norm since the last reliable update. double maxr_deflate = rNorm; // The maximum residual since the last deflation @@ -533,10 +384,17 @@ namespace quda double m_map = 2. 
/ (lambda_max - lambda_min); double b_map = -(lambda_max + lambda_min) / (lambda_max - lambda_min); + auto get_i = [](std::vector<std::vector<ColorSpinorField>> &p, int i) { + vector_ref<ColorSpinorField> p_i; + p_i.reserve(p.size()); + for (auto &pi : p) p_i.push_back(pi[i]); + return p_i; + }; + blas::copy(S[0], r); // no op if uni-precision PrintStats("CA-CG", total_iter, r2, b2, heavy_quark_res); - while (!convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) { + while (!convergenceL2(r2, stop) && total_iter < param.maxiter) { // build up a space of size n_krylov, assumes S[0] is in place computeCAKrylovSpace(matSloppy, AS, S, n_krylov, basis, m_map, b_map); @@ -554,14 +412,28 @@ // Compute the beta coefficients for updating Q, AQ // 1. compute matrix Q_AS = -Q^\dagger AS // 2. Solve Q_AQ beta = Q_AS - blas::block::reDotProduct(Q_AS, AQ, S); - compute_beta(); + for (auto i = 0u; i < b.size(); i++) { + auto AQi = get_i(AQ, i); + auto Si = get_i(S, i); + blas::block::reDotProduct(Q_AS[i], AQi, Si); + compute_beta(i); + } // update direction vectors - blas::block::axpyz(beta, Q, S, Qtmp); + for (auto i = 0u; i < b.size(); i++) { + auto Qi = get_i(Q, i); + auto Si = get_i(S, i); + auto Qtmpi = get_i(Qtmp, i); + blas::block::axpyz(beta[i], Qi, Si, Qtmpi); + } for (int i = 0; i < n_krylov; i++) std::swap(Q[i], Qtmp[i]); - blas::block::axpyz(beta, AQ, AS, Qtmp); + for (auto i = 0u; i < b.size(); i++) { + auto AQi = get_i(AQ, i); + auto ASi = get_i(AS, i); + auto Qtmpi = get_i(Qtmp, i); + blas::block::axpyz(beta[i], AQi, ASi, Qtmpi); + } for (int i = 0; i < n_krylov; i++) std::swap(AQ[i], Qtmp[i]); } @@ -569,19 +441,34 @@ // 1. Compute Q_AQ = Q^\dagger AQ and g = Q^dagger r = Q^dagger S[0] // 2.
Solve "Q_AQ" alpha = g - blas::block::reDotProduct(Q_AQandg, S, {AS, S[0]}); - compute_alpha(); + for (auto i = 0u; i < b.size(); i++) { + auto Si = get_i(S, i); + auto ASi = get_i(AS, i); + blas::block::reDotProduct(Q_AQandg[i], Si, {ASi, Si[0]}); + compute_alpha(i); + } // update the solution vector - blas::block::axpy(alpha, S, x); + for (auto i = 0u; i < b.size(); i++) { + auto Si = get_i(S, i); + blas::block::axpy(alpha[i], Si, x[i]); + } // no need to update AS - r2 = Q_AQandg[param.Nkrylov]; // actually the old r2... so we do one more iter than needed... + for (auto i = 0u; i < r2.size(); i++) r2[i] = Q_AQandg[i][param.Nkrylov]; } // NOTE: Because we always carry around the residual from an iteration before, we @@ -611,13 +505,13 @@ namespace quda // update since n_krylov or maxiter reached, converged or reliable update required // note that the heavy quark residual will by definition only be checked every n_krylov steps // Note: this won't reliable update when the norm _increases_. - if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || reliable(rNorm, maxrr, rUpdate, r2, delta)) { + if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || reliable(rNorm, maxrr, rUpdate, r2[0], delta)) { if ((r2 < stop || total_iter >= param.maxiter) && param.sloppy_converge) break; mat(r, x); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < maxr_deflate * param.tol_restart) { // Deflate and add solution to accumulator eig_solve->deflate(x, r, evecs, evals, true); @@ -625,12 +519,15 @@ namespace quda mat(r, x); r2 = blas::xmyNorm(b, r); - maxr_deflate = sqrt(r2); + maxr_deflate = sqrt(r2[0]); } blas::copy(S[0], r); - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } // break-out check if we have reached the limit of the precision if (r2 > r2_old) { @@ -638,7 +535,7 @@ namespace quda resIncreaseTotal++; warningQuda( "CA-CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), sqrt(r2_old), resIncreaseTotal); + sqrt(r2[0]), sqrt(r2_old[0]), resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("CA-CG: solver exiting due to too many true residual norm increases"); break; @@ -660,10 +557,12 @@ namespace quda if (param.compute_true_res) { // Calculate the true residual mat(r, x); - double true_res = blas::xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); - param.true_res_hq - = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
sqrt(blas::HeavyQuarkResidualNorm(x, r).z) : 0.0; + auto true_res = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_res[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } if (!param.is_preconditioner) { @@ -671,7 +570,7 @@ namespace quda param.iter += total_iter; } - PrintSummary("CA-CG", total_iter, r2, b2, stop, param.tol_hq); + PrintSummary("CA-CG", total_iter, r2, b2, stop); if (param.is_preconditioner) commGlobalReductionPop(); } diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index 4610df455e..aa41c9409b 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -21,18 +21,19 @@ namespace quda destroyDeflationSpace(); } - void CAGCR::create(ColorSpinorField &x, const ColorSpinorField &b) + void CAGCR::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); - alpha.resize(param.Nkrylov); + alpha.resize(b.size()); + for (auto i = 0u; i < alpha.size(); i++) alpha[i].resize(param.Nkrylov); bool mixed = param.precision != param.precision_sloppy; - ColorSpinorParam csParam(b); + ColorSpinorParam csParam(b[0]); csParam.create = QUDA_NULL_FIELD_CREATE; csParam.setPrecision(param.precision_sloppy); @@ -41,27 +42,30 @@ namespace quda p.resize(param.Nkrylov + 1); q.resize(param.Nkrylov); for (int i = 0; i < param.Nkrylov + 1; i++) { - p[i] = ColorSpinorField(csParam); - if (i > 0) q[i - 1] = p[i].create_alias(csParam); + resize(p[i], b.size(), csParam); + if (i > 0) create_alias(q[i - 1], p[i]); } } else { p.resize(param.Nkrylov); q.resize(param.Nkrylov); for (int i = 0; i < param.Nkrylov; i++) { - p[i] = ColorSpinorField(csParam); - q[i] = ColorSpinorField(csParam); + resize(p[i], b.size(), csParam); + resize(q[i], b.size(), csParam); } } csParam.setPrecision(param.precision); - r = mixed ? ColorSpinorField(csParam) : p[0].create_alias(csParam); + if (mixed) + resize(r, b.size(), csParam); + else + create_alias(r, p[0]); if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); init = true; } // init } - void CAGCR::solve(std::vector &psi_, std::vector &q, ColorSpinorField &b) + void CAGCR::solve(std::vector &psi_, cvector_ref &q, ColorSpinorField &b) { typedef Matrix matrix; typedef Matrix vector; @@ -113,7 +117,7 @@ namespace quda } } - ColorSpinorField &CAGCR::get_residual() + cvector_ref CAGCR::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -127,7 +131,7 @@ namespace quda 3. Update solution and residual vectors 4. (Optional) restart if convergence or maxiter not reached */ - void CAGCR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CAGCR::operator()(cvector_ref &x, cvector_ref &b) { const int n_krylov = param.Nkrylov; @@ -143,16 +147,16 @@ namespace quda // compute b2, but only if we need to bool fixed_iteration = param.sloppy_converge && n_krylov == param.maxiter && !param.compute_true_res; - double b2 = !fixed_iteration ? blas::norm2(b) : 1.0; - double r2 = 0.0; // if zero source then we will exit immediately doing no work + auto b2 = !fixed_iteration ? blas::norm2(b) : vector(b.size(), 1.0); + std::vector r2(b.size(), 0.0); // if zero source then we will exit immediately doing no work if (param.deflate) { // Construct the eigensolver and deflation space if requested. 
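CAGCR::solve, now invoked once per right-hand side, picks the coefficients psi that minimize the residual over the Krylov images q_k = A p_k. A standalone least-squares sketch: the solver assembles a small reduced system from fused reductions, whereas here a column-pivoted QR on an explicit matrix stands in:

#include <Eigen/Dense>
#include <complex>
#include <vector>

using Complex = std::complex<double>;

// Choose psi minimizing ||b - Q psi|| for one right-hand side.
std::vector<Complex> solve_psi(const Eigen::MatrixXcd &Q, const Eigen::VectorXcd &b)
{
  Eigen::VectorXcd psi = Q.colPivHouseholderQr().solve(b);
  return {psi.data(), psi.data() + psi.size()};
}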
if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { @@ -217,7 +221,7 @@ namespace quda // Perform 100 power iterations, normalizing every 10 mat-vecs, using r_ as an initial seed // and q[0]/q[1] as temporaries for the power iterations. Technically illegal if n_krylov == 1, but in that case lambda_max isn't used anyway. - lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r, q[0], q[1], 100, 10); + lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r[0], q[0][0], q[1][0], 100, 10); logQuda(QUDA_SUMMARIZE, "CA-GCR Approximate lambda max = 1.1 x %e\n", lambda_max / 1.1); lambda_init = true; @@ -233,19 +237,13 @@ namespace quda double b_map = -(lambda_max + lambda_min) / (lambda_max - lambda_min); // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - return; - } else { - b2 = r2; - } + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); + return; } - double stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : 0.0; // stopping condition of solver + auto stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : + std::vector(b.size(), 0.0); // stopping condition of solver const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
true : false; @@ -255,8 +253,11 @@ const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance const int maxResIncreaseTotal = param.max_res_increase_total; - double heavy_quark_res = 0.0; // heavy quark residual - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + std::vector<double> heavy_quark_res(b.size(), 0.0); // heavy quark residual + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } int resIncrease = 0; int resIncreaseTotal = 0; @@ -267,29 +268,41 @@ } int total_iter = 0; int restart = 0; - double r2_old = r2; - double maxr_deflate = sqrt(r2); + auto r2_old = r2; + auto maxr_deflate = sqrt(r2[0]); bool l2_converge = false; blas::copy(p[0], r); // no op if uni-precision + auto get_i = [](std::vector<std::vector<ColorSpinorField>> &p, int i) { + vector_ref<ColorSpinorField> p_i; + p_i.reserve(p.size()); + for (auto &pi : p) p_i.push_back(pi[i]); + return p_i; + }; + PrintStats("CA-GCR", total_iter, r2, b2, heavy_quark_res); - while (!convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) { + while (!convergenceL2(r2, stop) && total_iter < param.maxiter) { // build up a space of size n_krylov computeCAKrylovSpace(matSloppy, q, p, n_krylov, basis, m_map, b_map); - solve(alpha, q, p[0]); + for (auto i = 0u; i < b.size(); i++) solve(alpha[i], get_i(q, i), p[0][i]); // need to make sure P is only length n_krylov - blas::block::caxpy(alpha, {p.begin(), p.begin() + n_krylov}, {x}); + for (auto i = 0u; i < b.size(); i++) { + auto pi = get_i(p, i); + blas::block::caxpy(alpha[i], {pi.begin(), pi.begin() + n_krylov}, {x[i]}); + } // no need to compute residual vector if not returning // residual vector and only doing a single fixed iteration if (!fixed_iteration || param.return_residual) { + for (auto i = 0u; i < b.size(); i++) { + // update the residual vector - for (int i = 0; i < n_krylov; i++) alpha[i] = -alpha[i]; - blas::block::caxpy(alpha, q, r); + for (int j = 0; j < n_krylov; j++) alpha[i][j] = -alpha[i][j]; + blas::block::caxpy(alpha[i], get_i(q, i), r[i]); + } } total_iter += n_krylov; @@ -302,13 +315,13 @@ // update since n_krylov or maxiter reached, converged or reliable update required // note that the heavy quark residual will by definition only be checked every n_krylov steps - if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || sqrt(r2 / r2_old) < param.delta) { + if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || sqrt(r2[0] / r2_old[0]) < param.delta) { if ((r2 < stop || total_iter >= param.maxiter) && param.sloppy_converge) break; mat(r, x); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < maxr_deflate * param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflateSVD(x, r, evecs, evals, true); @@ -316,10 +329,13 @@ mat(r, x); r2 = blas::xmyNorm(b, r); - maxr_deflate = sqrt(r2); + maxr_deflate = sqrt(r2[0]); } - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } // break-out check if we have reached the limit of the precision if (r2 > r2_old) { @@ -327,7 +343,7 @@ resIncrease++; resIncreaseTotal++; warningQuda(
"CA-GCR: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), sqrt(r2_old), resIncreaseTotal); + sqrt(r2[0]), sqrt(r2_old[0]), resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("CA-GCR: solver exiting due to too many true residual norm increases"); break; @@ -340,7 +356,7 @@ namespace quda } // No matter what, if we haven't converged, we do a restart. - if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (!convergenceL2(r2, stop)) { restart++; // restarting if residual is still too great PrintStats("CA-GCR (restart)", restart, r2, b2, heavy_quark_res); @@ -361,10 +377,12 @@ namespace quda if (param.compute_true_res) { // Calculate the true residual mat(r, x); - double true_res = blas::xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); - param.true_res_hq - = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? sqrt(blas::HeavyQuarkResidualNorm(x, r).z) : 0.0; + auto true_r2 = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } if (!param.is_preconditioner) { @@ -372,7 +390,7 @@ namespace quda param.iter += total_iter; } - PrintSummary("CA-GCR", total_iter, r2, b2, stop, param.tol_hq); + PrintSummary("CA-GCR", total_iter, r2, b2, stop); if (param.is_preconditioner) commGlobalReductionPop(); } diff --git a/lib/inv_cg3_quda.cpp b/lib/inv_cg3_quda.cpp index 30b9fff86c..75d6cc93d9 100644 --- a/lib/inv_cg3_quda.cpp +++ b/lib/inv_cg3_quda.cpp @@ -15,175 +15,30 @@ namespace quda { { } - CG3NE::CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m) : - CG3(mmdag, mmdagSloppy, mmdagPrecon, param), - mmdag(mat.Expose()), - mmdagSloppy(matSloppy.Expose()), - mmdagPrecon(matPrecon.Expose()) - { - } - - void CG3NE::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(x); - csParam.create = QUDA_NULL_FIELD_CREATE; - xp = ColorSpinorField(csParam); - csParam.create = QUDA_ZERO_FIELD_CREATE; - yp = ColorSpinorField(csParam); - init = true; - } - } - - ColorSpinorField &CG3NE::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - // CG3 residual will match the CG3NE residual (FIXME: but only with zero initial guess?) - return param.use_init_guess ? xp : CG3::get_residual(); - } - - // CG3NE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CG3NE::operator()(ColorSpinorField &x, ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = param.compute_true_res ? 
blas::norm2(b) : 0.0; - - if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - // compute initial residual - mmdag.Expose()->M(xp, x); - if (param.compute_true_res && b2 == 0.0) - b2 = blas::xmyNorm(b, xp); - else - blas::xpay(b, -1.0, xp); - - // compute solution to residual equation - CG3::operator()(yp, xp); - - mmdag.Expose()->Mdag(xp, yp); - - // compute full solution - blas::xpy(xp, x); - } else { - CG3::operator()(yp, b); - mmdag.Expose()->Mdag(x, yp); - } - - if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { - // compute the true residual - mmdag.Expose()->M(xp, x); - blas::xpay(b, -1.0, xp); // xp now holds the residual - - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, xp); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(xp); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CG3NE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - - CG3NR::CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param) : - CG3(mdagm, mdagmSloppy, mdagmPrecon, param), - mdagm(mat.Expose()), - mdagmSloppy(matSloppy.Expose()), - mdagmPrecon(matPrecon.Expose()) - { - } - - void CG3NR::create(ColorSpinorField &x, const ColorSpinorField &b) + void CG3::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b) { Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(b); + if (!init || r.size() != b.size()) { + ColorSpinorParam csParam(b[0]); csParam.create = QUDA_ZERO_FIELD_CREATE; - br = ColorSpinorField(csParam); - init = true; - } - } - - ColorSpinorField &CG3NR::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - return br; - } - - // CG3NR: Mdag M x = Mdag b is solved. - void CG3NR::operator()(ColorSpinorField &x, ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = 0.0; - if (param.compute_true_res) { - b2 = blas::norm2(b); - if (b2 == 0.0) { // compute initial residual vector - mdagm.Expose()->M(br, x); - b2 = blas::norm2(br); - } - } - - mdagm.Expose()->Mdag(br, b); - CG3::operator()(x, br); - - if (param.compute_true_res || param.return_residual) { - // compute the true residual - mdagm.Expose()->M(br, x); - blas::xpay(b, -1.0, br); // br now holds the residual - - if (param.compute_true_res) { - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, br); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(br); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CG3NR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - } - - void CG3::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - - if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_ZERO_FIELD_CREATE; - r = ColorSpinorField(csParam); - y = ColorSpinorField(csParam); + resize(r, b.size(), csParam); + resize(y, b.size(), csParam); // Sloppy fields - const bool mixed_precision = (param.precision != param.precision_sloppy); csParam.setPrecision(param.precision_sloppy); - ArS = ColorSpinorField(csParam); - rS_old = ColorSpinorField(csParam); - rS = mixed_precision ?
ColorSpinorField(csParam) : r.create_alias(); - xS = mixed_precision ? ColorSpinorField(csParam) : x.create_alias(); - xS_old = mixed_precision ? ColorSpinorField(csParam) : y.create_alias(); - tmp = ColorSpinorField(csParam); + resize(ArS, b.size(), csParam); + resize(rS_old, b.size(), csParam); + if (param.precision != param.precision_sloppy) { + resize(rS, b.size(), csParam); + resize(xS, b.size(), csParam); + resize(xS_old, b.size(), csParam); + } else { + create_alias(rS, r); + create_alias(xS, x); + create_alias(xS_old, y); + } + resize(tmp, b.size(), csParam); init = true; } @@ -192,32 +47,28 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &CG3::get_residual() + cvector_ref CG3::get_residual() { if (!init) errorQuda("No residual vector present"); return r; } - void CG3::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CG3::operator()(cvector_ref &x, cvector_ref &b) { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); // Check to see that we're not trying to invert on a zero-field source - double b2 = blas::norm2(b); - if (b2 == 0 - && (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO || param.use_init_guess == QUDA_USE_INIT_GUESS_NO)) { + auto b2 = blas::norm2(b); + if (is_zero_src(x, b, b2)) { getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); - printfQuda("Warning: inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; return; } const bool mixed_precision = (param.precision != param.precision_sloppy); create(x, b); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = vector(b.size(), param.tol_hq); const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
true : false; @@ -233,17 +84,18 @@ namespace quda { // these are only used if we use the heavy_quark_res const int hqmaxresIncrease = maxResIncrease + 1; int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual - double heavy_quark_res = 0.0; // heavy quark residual - double heavy_quark_res_old = 0.0; // heavy quark residual + vector heavy_quark_res(b.size(), 0.0); // heavy quark residual + vector heavy_quark_res_old(b.size(), 0.0); // heavy quark residual int hqresIncrease = 0; bool L2breakdown = false; // compute initial residual depending on whether we have an initial guess or not - double r2; + vector r2; if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; if (mixed_precision) { blas::copy(y, x); blas::zero(xS); @@ -260,16 +112,17 @@ namespace quda { blas::copy(rS, r); if (use_heavy_quark_res) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < hq.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); heavy_quark_res_old = heavy_quark_res; } getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); - if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) return; + if (convergence(r2, heavy_quark_res, stop, stop_hq)) return; getProfile().TPSTART(QUDA_PROFILE_COMPUTE); - double r2_old = r2; - double rNorm = sqrt(r2); + auto r2_old = r2; + double rNorm = sqrt(r2[0]); double r0Norm = rNorm; double maxrx = rNorm; double maxrr = rNorm; @@ -278,21 +131,23 @@ namespace quda { int k = 0; PrintStats("CG3", k, r2, b2, heavy_quark_res); - double rho = 1.0, gamma = 1.0; + vector rho(b.size(), 1.0); + vector gamma(b.size(), 1.0); - while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && k < param.maxiter) { + while (!convergence(r2, heavy_quark_res, stop, stop_hq) && k < param.maxiter) { matSloppy(ArS, rS); - double gamma_old = gamma; - double rAr = blas::reDotProduct(rS,ArS); - gamma = r2 / rAr; + auto gamma_old = gamma; + auto rAr = blas::reDotProduct(rS, ArS); + for (auto i = 0u; i < b.size(); i++) gamma[i] = r2[i] / rAr[i]; // CG3 step if (k == 0 || restart) { // First iteration r2 = blas::quadrupleCG3InitNorm(gamma, xS, rS, xS_old, rS_old, ArS); restart = false; } else { - rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old)); + for (auto i = 0u; i < rho.size(); i++) + rho[i] = rho[i] / (rho[i] - (gamma[i] / gamma_old[i]) * (r2[i] / r2_old[i])); r2_old = r2; r2 = blas::quadrupleCG3UpdateNorm(gamma, rho, xS, rS, xS_old, rS_old, ArS); } @@ -302,25 +157,27 @@ namespace quda { if (use_heavy_quark_res && k % heavy_quark_check == 0) { heavy_quark_res_old = heavy_quark_res; if (mixed_precision) { - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xS, y, rS).z); + auto hq = blas::xpyHeavyQuarkResidualNorm(xS, y, rS); + for (auto i = 0u; i < b2.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } else { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(xS, rS).z); + auto hq = blas::HeavyQuarkResidualNorm(xS, rS); + for (auto i = 0u; i < b2.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } } // reliable update conditions if (mixed_precision) { - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); if (rNorm > maxrx) maxrx = rNorm; if (rNorm > maxrr) maxrr = rNorm; bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx); // condition for x update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr)); // condition for r // force a reliable update if we are within 
target tolerance (only if doing reliable updates) - if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) update = true; + if (convergence(r2, heavy_quark_res, stop, stop_hq) && param.delta >= param.tol) update = true; // For heavy-quark inversion force a reliable update if we continue after - if ( use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol ) { + if (use_heavy_quark_res and L2breakdown and convergenceHQ(heavy_quark_res, stop_hq) and param.delta >= param.tol) { update = true; } @@ -330,33 +187,35 @@ namespace quda { blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); - param.true_res = sqrt(r2 / b2); + for (auto i = 0u; i < b2.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); if (use_heavy_quark_res) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + auto hq = blas::HeavyQuarkResidualNorm(y, r); + for (auto i = 0u; i < b2.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); param.true_res_hq = heavy_quark_res; } - rNorm = sqrt(r2); - r0Norm = sqrt(r2); + rNorm = sqrt(r2[0]); + r0Norm = sqrt(r2[0]); maxrr = rNorm; maxrx = rNorm; // we update sloppy and old fields - if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (!convergence(r2, heavy_quark_res, stop, stop_hq)) { blas::copy(rS, r); blas::axpy(-1., xS, xS_old); // we preserve the orthogonality between the previous residual and the new - Complex rr_old = blas::cDotProduct(rS, rS_old); - r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old); + auto rr_old = blas::cDotProduct(rS, rS_old); + for (auto i = 0u; i < r2.size(); i++) rr_old[i] /= r2[i]; + r2_old = blas::caxpyNorm(-rr_old, rS, rS_old); blas::zero(xS); } } // break-out check if we have reached the limit of the precision - if (sqrt(r2) > r0Norm) { + if (sqrt(r2[0]) > r0Norm) { resIncrease++; resIncreaseTotal++; warningQuda( "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), r0Norm, resIncreaseTotal); + sqrt(r2[0]), r0Norm, resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { if (use_heavy_quark_res) { L2breakdown = true; @@ -376,9 +235,10 @@ namespace quda { warningQuda("CG3: Restarting without reliable updates for heavy-quark residual"); restart = true; L2breakdown = false; - if (heavy_quark_res > heavy_quark_res_old) { + if (heavy_quark_res[0] > heavy_quark_res_old[0]) { hqresIncrease++; - warningQuda("CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old); + warningQuda("CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", + heavy_quark_res[0], heavy_quark_res_old[0]); // break out if we do not improve here anymore if (hqresIncrease > hqmaxresIncrease) { warningQuda("CG3: solver exiting due to too many heavy quark residual norm increases"); break; @@ -387,25 +247,26 @@ namespace quda { } } } else { - if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (convergence(r2, heavy_quark_res, stop, stop_hq)) { mat(r, x); r2 = blas::xmyNorm(b, r); - r0Norm = sqrt(r2); + r0Norm = sqrt(r2[0]); // we update sloppy and old fields - if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (!convergence(r2, heavy_quark_res, stop, stop_hq)) { // we preserve the orthogonality between the previous residual and the new
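The projection carried out here keeps the previous residual orthogonal to the new one: rS_old loses its component along rS, and caxpyNorm hands back the updated norm in one pass. A scalar-field sketch of the same arithmetic, assuming the cDotProduct convention of conj(x) * y:

#include <complex>
#include <vector>

using Complex = std::complex<double>;

// rS_old -= (<rS, rS_old> / |rS|^2) * rS, returning the new |rS_old|^2,
// mirroring cDotProduct followed by caxpyNorm above.
double reorthogonalize(const std::vector<Complex> &rS, std::vector<Complex> &rS_old, double r2)
{
  Complex rr_old {};
  for (size_t e = 0; e < rS.size(); e++) rr_old += std::conj(rS[e]) * rS_old[e];
  rr_old /= r2;
  double norm2 = 0.0;
  for (size_t e = 0; e < rS.size(); e++) {
    rS_old[e] -= rr_old * rS[e];
    norm2 += std::norm(rS_old[e]);
  }
  return norm2;
}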
- Complex rr_old = blas::cDotProduct(rS, rS_old); - r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old); + auto rr_old = blas::cDotProduct(rS, rS_old); + for (auto i = 0u; i < r2.size(); i++) rr_old[i] /= r2[i]; + r2_old = blas::caxpyNorm(-rr_old, rS, rS_old); } } // break-out check if we have reached the limit of the precision - if (sqrt(r2) > r0Norm) { + if (sqrt(r2[0]) > r0Norm) { resIncrease++; resIncreaseTotal++; warningQuda( "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), r0Norm, resIncreaseTotal); + sqrt(r2[0]), r0Norm, resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("CG3: solver exiting due to too many true residual norm increases"); break; @@ -427,11 +288,15 @@ namespace quda { // compute the true residuals if (!mixed_precision && param.compute_true_res) { mat(r, x); - param.true_res = sqrt(blas::xmyNorm(b, r) / b2); - if (use_heavy_quark_res) param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + r2 = blas::xmyNorm(b, r); + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) param.true_res_hq[i] = sqrt(hq[i].z); + } } - PrintSummary("CG3", k, r2, b2, stop, param.tol_hq); + PrintSummary("CG3", k, r2, b2, stop, stop_hq); getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 19b1d1225d..43df31e1b2 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -27,186 +27,41 @@ namespace quda { CG::~CG() { destroyDeflationSpace(); } - void CG::create(ColorSpinorField &x, const ColorSpinorField &b) + void CG::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b) { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { getProfile().TPSTART(QUDA_PROFILE_INIT); - ColorSpinorParam csParam(x); - csParam.create = QUDA_NULL_FIELD_CREATE; - r = ColorSpinorField(csParam); - y = ColorSpinorField(csParam); + resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); + resize(y, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); // sloppy fields - csParam.setPrecision(param.precision_sloppy); - p = ColorSpinorField(csParam); - Ap = ColorSpinorField(csParam); - - rSloppy = (r.Precision() != param.precision_sloppy) ? ColorSpinorField(csParam) : r.create_alias(); - param.use_sloppy_partial_accumulator = false; // hard-code precise accumulation - xSloppy = (param.use_sloppy_partial_accumulator == true) ? ColorSpinorField(csParam) : x.create_alias(); - - init = true; - getProfile().TPSTOP(QUDA_PROFILE_INIT); - } - } - - CGNE::CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param) : - CG(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param), - mmdag(mat.Expose()), - mmdagSloppy(matSloppy.Expose()), - mmdagPrecon(matPrecon.Expose()), - mmdagEig(matEig.Expose()) - { - } - - void CGNE::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(x); + ColorSpinorParam csParam(x[0]); csParam.create = QUDA_NULL_FIELD_CREATE; - xp = ColorSpinorField(csParam); - csParam.create = QUDA_ZERO_FIELD_CREATE; - yp = ColorSpinorField(csParam); - init = true; - } - } - - ColorSpinorField &CGNE::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - // CG residual will match the CGNE residual (FIXME: but only with zero initial guess?)
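The create() above either allocates separate sloppy fields (mixed precision) or aliases the full-precision ones, so that the later blas::copy calls degenerate to no-ops in the uniform case. A standalone sketch of that alias-or-allocate choice, with a placeholder Field type:

#include <vector>

struct Field { /* placeholder for a ColorSpinorField */ };

// Owned storage only in the mixed-precision case; otherwise plain aliases.
struct SloppyViews {
  std::vector<Field> storage;
  std::vector<Field *> view; // what the iteration actually reads and writes
  void setup(std::vector<Field> &full, bool mixed)
  {
    view.clear();
    if (mixed) {
      storage.resize(full.size());
      for (auto &f : storage) view.push_back(&f);
    } else {
      for (auto &f : full) view.push_back(&f); // alias, no allocation
    }
  }
};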
- return param.use_init_guess ? xp : CG::get_residual(); - } - - // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CGNE::operator()(ColorSpinorField &x, ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = param.compute_true_res ? blas::norm2(b) : 0.0; - - if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - // compute initial residual - mmdag.Expose()->M(xp, x); - if (param.compute_true_res && b2 == 0.0) - b2 = blas::xmyNorm(b, xp); - else - blas::xpay(b, -1.0, xp); - - // compute solution to residual equation - CG::operator()(yp, xp); - - mmdag.Expose()->Mdag(xp, yp); - - // compute full solution - blas::xpy(xp, x); - } else { - CG::operator()(yp, b); - mmdag.Expose()->Mdag(x, yp); - } - - if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { - // compute the true residual - mmdag.Expose()->M(xp, x); - blas::xpay(b, -1.0, xp); // xp now holds the residual + csParam.setPrecision(param.precision_sloppy); + resize(p, b.size(), csParam); + resize(Ap, b.size(), csParam); - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, xp); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); + if (param.precision != param.precision_sloppy) { + resize(r_sloppy, b.size(), csParam); } else { - r2 = blas::norm2(xp); + create_alias(r_sloppy, r); } - param.true_res = sqrt(r2 / b2); - PrintSummary("CA-CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - - CGNR::CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param) : - CG(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param), - mdagm(mat.Expose()), - mdagmSloppy(matSloppy.Expose()), - mdagmPrecon(matPrecon.Expose()), - mdagmEig(matEig.Expose()) - { - } + param.use_sloppy_partial_accumulator = false; // hard-code precise accumulation + if (param.use_sloppy_partial_accumulator) resize(x_sloppy, b.size(), csParam); - void CGNR::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_ZERO_FIELD_CREATE; - br = ColorSpinorField(csParam); init = true; - } - } - - ColorSpinorField &CGNR::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - return br; - } - - // CGNR: Mdag M x = Mdag b is solved.
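As the comments on the removed wrappers state, the normal-equation variants reduce a general system M x = b to a Hermitian one: CGNE solves M Mdag y = b and returns x = Mdag y, while CGNR solves Mdag M x = Mdag b. A minimal sketch of the CGNE reduction, with generic callables in place of DiracMatrix:

#include <functional>
#include <vector>

using Vec = std::vector<double>;
using Op = std::function<Vec(const Vec &)>;

// Given M, its adjoint, and a solver for Hermitian systems, solve M x = b:
// y = (M Mdag)^{-1} b, then x = Mdag y, since M x = M Mdag y = b.
Vec solve_cgne(const Op &M, const Op &Mdag,
               const std::function<Vec(const Op &, const Vec &)> &hermitian_solve, const Vec &b)
{
  Op MMdag = [&](const Vec &v) { return M(Mdag(v)); };
  Vec y = hermitian_solve(MMdag, b);
  return Mdag(y);
}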
- void CGNR::operator()(ColorSpinorField &x, ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = 0.0; - if (param.compute_true_res) { - b2 = blas::norm2(b); - if (b2 == 0.0) { // compute initial residual vector - mdagm.Expose()->M(br, x); - b2 = blas::norm2(br); - } + getProfile().TPSTOP(QUDA_PROFILE_INIT); } - mdagm.Expose()->Mdag(br, b); - CG::operator()(x, br); - - if (param.compute_true_res || param.return_residual) { - // compute the true residual - mdagm.Expose()->M(br, x); - blas::xpay(b, -1.0, br); // br now holds the residual - - if (param.compute_true_res) { - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, br); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(br); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } + // need to reset x_sloppy every solve + if (!param.use_sloppy_partial_accumulator) create_alias(x_sloppy, x); } - void CG::operator()(ColorSpinorField &x, ColorSpinorField &b, ColorSpinorField *p_init, double r2_old_init) + void CG::operator()(cvector_ref &x, cvector_ref &b, + cvector_ref &p_init, cvector &r2_old_init) { if (param.is_preconditioner) commGlobalReductionPush(param.global_reduction); @@ -216,7 +71,7 @@ namespace quda { } const int Np = (param.solution_accumulator_pipeline == 0 ? 1 : param.solution_accumulator_pipeline); - if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np); + if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline", Np); // Determine whether or not we're doing a heavy quark residual const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; @@ -231,8 +86,6 @@ namespace quda { // just in case HQ residual solves are split into a separate file if (use_heavy_quark_res) errorQuda("The \"vanilla\" CG solver does not support HQ residual solves"); - // whether to select alternative reliable updates - bool alternative_reliable = param.use_alternative_reliable; /** When CG is used as a preconditioner, and we disable the `advanced features`, these features are turned off: - Reliable updates @@ -244,15 +97,14 @@ namespace quda { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); - double b2 = blas::norm2(b); + // whether to select alternative reliable updates + bool alternative_reliable = param.use_alternative_reliable; + + auto b2 = blas::norm2(b); // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); - printfQuda("Warning: inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); return; } @@ -260,7 +112,7 @@ namespace quda { if (param.deflate) { // Construct the eigensolver and deflation space if requested. - constructDeflationSpace(b, matEig); + constructDeflationSpace(b[0], matEig); if (deflate_compute) { // compute the deflation space. 
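The warm-start path a little further below accepts a previous search direction p_init together with its r2_old: the direction is re-projected against the fresh residual and blended with beta = r2 / r2_old. A single-system sketch of that update (the solver performs it per right-hand side with fused blas calls):

#include <complex>
#include <vector>

using Complex = std::complex<double>;

// p <- r + beta * (p - (<r, p> / r2) * r), with beta = r2 / r2_old,
// mirroring the cDotProduct / caxpy / xpayz sequence below.
void warm_start(std::vector<Complex> &p, const std::vector<Complex> &r, double r2, double r2_old)
{
  Complex rp {};
  for (size_t e = 0; e < r.size(); e++) rp += std::conj(r[e]) * p[e];
  rp /= r2;
  const double beta = r2 / r2_old;
  for (size_t e = 0; e < p.size(); e++) p[e] = r[e] + beta * (p[e] - rp * r[e]);
}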
if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); @@ -277,27 +129,28 @@ namespace quda { const double u = precisionEpsilon(param.precision_sloppy); const double uhigh = precisionEpsilon(); // solver precision - double Anorm = 0; - double beta = 0; + double Anorm = 0.0; + vector beta(b.size(), 0.0); // for alternative reliable updates if (advanced_feature && alternative_reliable) { // estimate norm for reliable updates - mat(r, b); - Anorm = sqrt(blas::norm2(r)/b2); + mat(r[0], b[0]); + Anorm = sqrt(blas::norm2(r[0]) / b2[0]); } // compute initial residual - double r2 = 0.0; + vector r2(b2.size(), 0.0); if (advanced_feature && param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute r = b - A * x mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; // y contains the original guess. blas::copy(y, x); } else { - if (&r != &b) blas::copy(r, b); + blas::copy(r, b); r2 = b2; blas::zero(y); } @@ -310,20 +163,24 @@ namespace quda { } blas::zero(x); - if (&x != &xSloppy) blas::zero(xSloppy); - blas::copy(rSloppy,r); - - ColorSpinorParam csParam(rSloppy); - csParam.create = QUDA_NULL_FIELD_CREATE; - XUpdateBatch x_update_batch(Np, p_init ? *p_init : rSloppy, csParam); - - double r2_old = 0.0; - if (r2_old_init != 0.0 and p_init) { - r2_old = r2_old_init; - Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); - blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); - beta = r2 / r2_old; - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_current_field()); + if (param.use_sloppy_partial_accumulator) blas::zero(x_sloppy); + blas::copy(r_sloppy, r); + + auto csParam(r_sloppy[0]); + std::vector x_update_batch(b.size()); + for (auto i = 0u; i < b.size(); i++) + x_update_batch[i] = XUpdateBatch(Np, !p_init[i].empty() ? 
p_init[i] : r_sloppy[i], csParam); + + vector r2_old(r2.size(), 0.0); + for (auto i = 0u; i < b.size(); i++) { + if (r2_old_init[i] != 0.0 and !p_init[i].empty()) { + // FIXME vectorize this + r2_old[i] = r2_old_init[i]; + Complex rp = blas::cDotProduct(r_sloppy[i], x_update_batch[i].get_current_field()) / (r2[i]); + blas::caxpy(-rp, r_sloppy[i], x_update_batch[i].get_current_field()); + beta[i] = r2[i] / r2_old[i]; + blas::xpayz(r_sloppy[i], beta[i], x_update_batch[i].get_current_field(), x_update_batch[i].get_current_field()); + } } if (!param.is_preconditioner) { @@ -331,10 +188,9 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); } - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver - auto alpha = std::make_unique(Np); - double pAp; + vector pAp(b.size()); if (!param.is_preconditioner) { getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); @@ -343,9 +199,9 @@ namespace quda { int k = 0; - PrintStats("CG", k, r2, b2, 0.0); + PrintStats("CG", k, r2, b2); - bool converged = convergenceL2(r2, 0.0, stop, 0.0); + bool converged = convergenceL2(r2, stop); ReliableUpdatesParams ru_params; @@ -359,30 +215,55 @@ namespace quda { ru_params.maxResIncreaseTotal = param.max_res_increase_total; ru_params.use_heavy_quark_res = false; // since we've removed HQ residual support - ReliableUpdates ru(ru_params, r2); + ReliableUpdates ru(ru_params, r2[0]); + + auto get_p = [](std::vector &x_update_batch, bool next = false) { + vector_ref p; + p.reserve(x_update_batch.size()); + for (auto &x : x_update_batch) p.push_back(next ? x.get_next_field() : x.get_current_field()); + return p; + }; + + auto get_alpha = [](std::vector &x_update_batch) { + vector alpha; + alpha.reserve(x_update_batch.size()); + for (auto &x : x_update_batch) alpha.push_back(x.get_current_alpha()); + return alpha; + }; while ( !converged && k < param.maxiter ) { - matSloppy(Ap, x_update_batch.get_current_field()); - double sigma; + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); + matSloppy(Ap, p); + + vector sigma(b.size()); bool breakdown = false; if (advanced_feature && param.pipeline) { - double Ap2; - if(alternative_reliable){ - double4 quadruple = blas::quadrupleCGReduction(rSloppy, Ap, x_update_batch.get_current_field()); - r2 = quadruple.x; - Ap2 = quadruple.y; - pAp = quadruple.z; - ru.update_ppnorm(quadruple.w); + vector Ap2(b.size()); + if (alternative_reliable) { + auto quadruple = blas::quadrupleCGReduction(r_sloppy, Ap, p); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = quadruple[i].x; + Ap2[i] = quadruple[i].y; + pAp[i] = quadruple[i].z; + } + ru.update_ppnorm(quadruple[0].w); // using 0th system for RU } else { - double3 triplet = blas::tripleCGReduction(rSloppy, Ap, x_update_batch.get_current_field()); - r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z; + auto triplet = blas::tripleCGReduction(r_sloppy, Ap, p); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = triplet[i].x; + Ap2[i] = triplet[i].y; + pAp[i] = triplet[i].z; + } } r2_old = r2; - x_update_batch.get_current_alpha() = r2 / pAp; - sigma = x_update_batch.get_current_alpha() * (x_update_batch.get_current_alpha() * Ap2 - pAp); - if (sigma < 0.0 || ru.steps_since_reliable == 0) { // sigma condition has broken down - r2 = blas::axpyNorm(-x_update_batch.get_current_alpha(), Ap, rSloppy); + for (auto i = 0u; i < b.size(); i++) { + x_update_batch[i].get_current_alpha() = r2[i] / pAp[i]; + sigma[i] 
= x_update_batch[i].get_current_alpha() * (x_update_batch[i].get_current_alpha() * Ap2[i] - pAp[i]); + } + if (sigma[0] < 0.0 || ru.steps_since_reliable == 0) { // sigma condition has broken down + r2 = blas::axpyNorm(-get_alpha(x_update_batch), Ap, r_sloppy); sigma = r2; breakdown = true; } @@ -393,69 +274,72 @@ namespace quda { // alternative reliable updates, if (advanced_feature && alternative_reliable) { - double3 pAppp = blas::cDotProductNormA(x_update_batch.get_current_field(), Ap); - pAp = pAppp.x; - ru.update_ppnorm(pAppp.z); + auto pAppp = blas::cDotProductNormA(p, Ap); + for (auto i = 0u; i < b.size(); i++) pAp[i] = pAppp[i].x; + ru.update_ppnorm(pAppp[0].z); // using 0th system for RU } else { - pAp = blas::reDotProduct(x_update_batch.get_current_field(), Ap); + pAp = blas::reDotProduct(p, Ap); } - x_update_batch.get_current_alpha() = r2 / pAp; + for (auto i = 0u; i < b.size(); i++) x_update_batch[i].get_current_alpha() = r2[i] / pAp[i]; // here we are deploying the alternative beta computation - double2 cg_norm = blas::axpyCGNorm(-x_update_batch.get_current_alpha(), Ap, rSloppy); - r2 = cg_norm.x; // (r_new, r_new) - sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks + auto cg_norm = blas::axpyCGNorm(-get_alpha(x_update_batch), Ap, r_sloppy); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = cg_norm[i].x; // (r_new, r_new) + sigma[i] = cg_norm[i].y >= 0.0 ? cg_norm[i].y : r2[i]; // use r2 if (r_k+1, r_k+1-r_k) breaks + } } // reliable update conditions - ru.update_rNorm(sqrt(r2)); + ru.update_rNorm(sqrt(r2[0])); if (advanced_feature) { - ru.evaluate(r2_old); + ru.evaluate(r2_old[0]); // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergenceL2(r2, 0.0, stop, 0.0) && param.delta >= param.tol) ru.set_updateX(); + if (convergenceL2(r2, stop) && param.delta >= param.tol) ru.set_updateX(); } if (!ru.trigger()) { - beta = sigma / r2_old; // use the alternative beta computation + for (auto i = 0u; i < beta.size(); i++) beta[i] = sigma[i] / r2_old[i]; // use the alternative beta computation if (advanced_feature && param.pipeline && !breakdown) { if (Np == 1) { - blas::tripleCGUpdate(x_update_batch.get_current_alpha(), beta, Ap, xSloppy, rSloppy, - x_update_batch.get_current_field()); + blas::tripleCGUpdate(get_alpha(x_update_batch), beta, Ap, x_sloppy, r_sloppy, p); } else { errorQuda("Not implemented pipelined CG with Np > 1"); } } else { + if (Np == 1) { // with Np=1 we just run regular fusion between x and p updates - blas::axpyZpbx(x_update_batch.get_current_alpha(), x_update_batch.get_current_field(), xSloppy, rSloppy, - beta); + blas::axpyZpbx(get_alpha(x_update_batch), p, x_sloppy, r_sloppy, beta); } else { - if (x_update_batch.is_container_full()) { x_update_batch.accumulate_x(xSloppy); } + for (auto i = 0u; i < b.size(); i++) { + if (x_update_batch[i].is_container_full()) x_update_batch[i].accumulate_x(x_sloppy[i]); + } // p[(k+1)%Np] = r + beta * p[k%Np] - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); + blas::xpayz(r_sloppy, beta, p, p_next); } } // alternative reliable updates - if (advanced_feature) { ru.accumulate_norm(x_update_batch.get_current_alpha()); } + if (advanced_feature) { ru.accumulate_norm(x_update_batch[0].get_current_alpha()); } } else { - x_update_batch.accumulate_x(xSloppy); - x_update_batch.reset_next(); - - blas::copy(x, xSloppy); // nop when these pointers alias + for (auto i = 0u; i < b.size(); i++) { + 
x_update_batch[i].accumulate_x(x_sloppy[i]); + x_update_batch[i].reset_next(); + } + blas::xpy(x_sloppy, y); // swap these around? - blas::xpy(x, y); // swap these around? mat(r, y); // here we can use x as tmp r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < ru.maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < ru.maxr_deflate * param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); @@ -463,50 +347,56 @@ namespace quda { mat(r, y); r2 = blas::xmyNorm(b, r); - ru.update_maxr_deflate(r2); + ru.update_maxr_deflate(r2[0]); } - blas::copy(rSloppy, r); //nop when these pointers alias - blas::zero(xSloppy); + blas::copy(r_sloppy, r); // nop when these pointers alias + blas::zero(x_sloppy); - if (advanced_feature) { ru.update_norm(r2, y); } + if (advanced_feature) { ru.update_norm(r2[0], y[0]); } if (advanced_feature) { // needed as a "dummy parameter" to reliable_break. bool L2breakdown = false; - if (ru.reliable_break(r2, stop, L2breakdown, 0)) { break; } + if (ru.reliable_break(r2[0], stop[0], L2breakdown, 0)) { break; } } // explicitly restore the orthogonality of the gradient vector - Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); - blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); + + auto rp = blas::cDotProduct(r_sloppy, p); + for (auto i = 0u; i < b.size(); i++) rp[i] /= r2[i]; + blas::caxpy(-rp, r_sloppy, p); - beta = r2 / r2_old; - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); + for (auto i = 0u; i < beta.size(); i++) beta[i] = r2[i] / r2_old[i]; + blas::xpayz(r_sloppy, beta, p, p_next); - ru.reset(r2); + ru.reset(r2[0]); } breakdown = false; k++; - PrintStats("CG", k, r2, b2, 0.0); + PrintStats("CG", k, r2, b2); // check convergence - converged = convergenceL2(r2, 0.0, stop, 0.0); + converged = convergenceL2(r2, stop); // if we have converged and need to update any trailing solutions - if (converged && ru.steps_since_reliable > 0 && !x_update_batch.is_container_full()) { - x_update_batch.accumulate_x(xSloppy); - } + for (auto i = 0u; i < b.size(); i++) { + if (converged && ru.steps_since_reliable > 0 && !x_update_batch[i].is_container_full()) { + x_update_batch[i].accumulate_x(x_sloppy[i]); + } - if (ru.steps_since_reliable == 0) { - x_update_batch.reset(); - } else { - ++x_update_batch; + if (ru.steps_since_reliable == 0) { + x_update_batch[i].reset(); + } else { + ++x_update_batch[i]; + } } } - blas::copy(x, xSloppy); + blas::copy(x, x_sloppy); blas::xpy(y, x); if (!param.is_preconditioner) { @@ -523,25 +413,29 @@ namespace quda { if (advanced_feature && param.compute_true_res) { // compute the true residuals mat(r, x); - param.true_res = sqrt(blas::xmyNorm(b, r) / b2); - param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + auto true_r2 = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } - PrintSummary("CG", k, r2, b2, stop, 0.0); + PrintSummary("CG", k, r2, b2, stop); if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); if (param.is_preconditioner) commGlobalReductionPop(); } - ColorSpinorField &CG::get_residual() + cvector_ref CG::get_residual() { if (!init) errorQuda("No residual vector present"); return r; } // Separate HQ 
residual codepath - void CG::hqsolve(ColorSpinorField &x, ColorSpinorField &b) + void CG::hqsolve(cvector_ref &x, cvector_ref &b) { logQuda(QUDA_VERBOSE, "Performing a HQ CG solve\n"); @@ -557,7 +451,7 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_INIT); - double b2 = blas::norm2(b); + vector b2 = blas::norm2(b); // Detect whether this is a pure double solve or not; informs the necessity of some stability checks bool is_pure_double = (param.precision == QUDA_DOUBLE_PRECISION && param.precision_sloppy == QUDA_DOUBLE_PRECISION); @@ -565,12 +459,8 @@ namespace quda { bool heavy_quark_restart = false; // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + if (is_zero_src(x, b, b2)) { getProfile().TPSTOP(QUDA_PROFILE_INIT); - printfQuda("Warning: inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; return; } @@ -585,36 +475,45 @@ namespace quda { const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60; // compute initial residual - double r2 = 0.0; + vector r2(b.size()); if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute r = b - A * x mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; // y contains the original guess. blas::copy(y, x); } else { - if (&r != &b) blas::copy(r, b); + blas::copy(r, b); r2 = b2; blas::zero(y); } blas::zero(x); - if (&x != &xSloppy) blas::zero(xSloppy); - blas::copy(rSloppy, r); - blas::copy(p, rSloppy); + if (param.use_sloppy_partial_accumulator) blas::zero(x_sloppy); + blas::copy(r_sloppy, r); + blas::copy(p, r_sloppy); - double r2_old = 0.0; + vector r2_old(b.size(), 0.0); getProfile().TPSTOP(QUDA_PROFILE_INIT); getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = std::vector(b.size(), param.tol_hq); + + auto get_hq_res = [](cvector_ref &x, cvector_ref &r) { + auto hq_nrm = blas::HeavyQuarkResidualNorm(x, r); + vector hq_res(hq_nrm.size()); + for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_nrm[i].z); + return hq_res; + }; // compute the initial heavy quark residual - double hq_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + vector hq_res = get_hq_res(x, r); - double alpha, beta, sigma, pAp; + vector alpha(b.size()), beta(b.size()), sigma(b.size()), pAp(b.size()); // Whether or not we also need to compute the L2 norm const bool L2_required = param.residual_type & (QUDA_L2_RELATIVE_RESIDUAL | QUDA_L2_ABSOLUTE_RESIDUAL); @@ -631,20 +530,20 @@ namespace quda { PrintStats("CG", k, r2, b2, hq_res); - bool converged = convergence(r2, hq_res, stop, param.tol_hq); + bool converged = convergence(r2, hq_res, stop, stop_hq); // Various parameters related to restarts // Trackers for the L2 norm: // rNorm: current iterated |r| // r0Norm: computed |r| at the last reliable update - double rNorm = sqrt(r2); - double r0Norm = rNorm; + auto rNorm = sqrt(r2[0]); + auto r0Norm = rNorm; // If the computed |r| goes above r0Norm between reliable updates, // update this ceiling. This goes into "R" type reliable updates. - double maxrx = L2breakdown ? hq_res : rNorm; - double maxrr = L2breakdown ? hq_res : rNorm; + double maxrx = L2breakdown ? hq_res[0] : rNorm; + double maxrr = L2breakdown ? 
hq_res[0] : rNorm; // Triggers for explicitly counting residual updates and checking for L2breakdown. // * updateX broadly maps to if the iterated residual has dropped by a factor of delta @@ -665,7 +564,7 @@ namespace quda { // Trackers for the HQ residual // hq0Res: computed HQ residual at the last reliable updated - double hq0Res = hq_res; + auto hq0Res = hq_res; // Counter for the number of times in a row the computed heavy quark residual has // jumped above the previously computed heavy quark residual. @@ -688,20 +587,22 @@ namespace quda { pAp = blas::reDotProduct(p, Ap); - alpha = r2 / pAp; + for (auto i = 0u; i < alpha.size(); i++) alpha[i] = r2[i] / pAp[i]; // here we are deploying the alternative beta computation - double2 cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy); - r2 = cg_norm.x; // (r_new, r_new) - sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks - rNorm = sqrt(r2); + auto cg_norm = blas::axpyCGNorm(-alpha, Ap, r_sloppy); + for (auto i = 0u; i < cg_norm.size(); i++) { + r2[i] = cg_norm[i].x; // (r_new, r_new) + sigma[i] = cg_norm[i].y >= 0.0 ? cg_norm[i].y : r2[i]; // use r2 if (r_k+1, r_k+1-r_k) breaks + } + rNorm = sqrt(r2[0]); // If the iterated norm has dropped by more than a factor of delta, trigger // an update. The baseline we check against differs depending on if // we're still checking the L2 norm, or if that has converged/broken down and we're // now looking at the HQ residual. - if (!L2breakdown && (L2_required || convergenceL2(r2, hq_res, stop, param.tol_hq))) { + if (!L2breakdown && (L2_required || convergenceL2(r2, stop))) { // L2 based reliable update // If the iterated residual norm has gone above the most recent "baseline" norm, @@ -717,42 +618,50 @@ namespace quda { updateR = ((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX); } else { // hqresidual based reliable update - if (hq_res > maxrx) maxrx = hq_res; - if (hq_res > maxrr) maxrr = hq_res; + if (hq_res[0] > maxrx) maxrx = hq_res[0]; + if (hq_res[0] > maxrr) maxrr = hq_res[0]; // I'm making the decision to use `param.delta` for the hq_res check because // in some regards it's an L2-esque norm... // Has the iterated heavy quark residual dropped by a factor of delta^2 from the last // computed norm? - updateX = (hq_res < param.delta * param.delta * hq0Res && r0Norm <= maxrx); + updateX = (hq_res[0] < param.delta * param.delta * hq0Res[0] && r0Norm <= maxrx); // Has the iterated heavy quark residual dropped by a factor of delta relative // to the largest the iterated norm has been since the last update? 
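+ // (as with the L2-based checks above, the reliable-update triggers key off the 0th RHS: hq_res[0] and the scalar maxrr/maxrx)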
- updateR = ((hq_res < param.delta * param.delta * maxrr && hq0Res <= maxrr) || updateX); + updateR = ((hq_res[0] < param.delta * param.delta * maxrr && hq0Res[0] <= maxrr) || updateX); } // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, hq_res, stop, param.tol_hq) && param.delta >= param.tol) updateX = true; + if (convergence(r2, hq_res, stop, stop_hq) && param.delta >= param.tol) updateX = true; // force a reliable update based on the HQ residual if L2 breakdown has already happened - if (L2breakdown && (convergenceHQ(r2, hq_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) + if (L2breakdown && (convergenceHQ(hq_res, stop_hq) || (r2[0] / b2[0]) < hq_res_stall_check) && param.delta >= param.tol) updateX = true; if (!(updateR || updateX)) { // No reliable update needed - beta = sigma / r2_old; // use the alternative beta computation + for (auto i = 0u; i < beta.size(); i++) beta[i] = sigma[i] / r2_old[i]; // use the alternative beta computation + + blas::axpyZpbx(alpha, p, x_sloppy, r_sloppy, beta); - blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta); + auto get_hq_res2 = [](cvector_ref &x, cvector_ref &y, + cvector_ref &r) { + auto hq_nrm = blas::xpyHeavyQuarkResidualNorm(x, y, r); + vector hq_res(hq_nrm.size()); + for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_nrm[i].z); + return hq_res; + }; if (k % param.heavy_quark_check == 0) { - if (xSloppy.Precision() != rSloppy.Precision()) { - blas::copy(r, rSloppy); - hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, y, r).z); + if (param.precision != param.precision_sloppy) { + blas::copy(r, r_sloppy); + hq_res = get_hq_res2(x_sloppy, y, r); } else { - hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, y, rSloppy).z); + hq_res = get_hq_res2(x_sloppy, y, r_sloppy); } } @@ -762,23 +671,23 @@ namespace quda { // We're performing a reliable update // Accumulate p into x, accumulate x into the total solution y, explicitly recompute the residual vector - blas::axpy(alpha, p, xSloppy); - blas::copy(x, xSloppy); // no op when these pointers alias + blas::axpy(alpha, p, x_sloppy); + blas::copy(x, x_sloppy); // no op when these pointers alias blas::xpy(x, y); mat(r, y); // Recompute the exact residual and heavy quark residual r2 = blas::xmyNorm(b, r); - rNorm = sqrt(r2); - hq_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + rNorm = sqrt(r2[0]); + hq_res = get_hq_res(y, r); // Copy and update fields - blas::copy(rSloppy, r); // no op when these pointers alias - blas::zero(xSloppy); + blas::copy(r_sloppy, r); // no op when these pointers alias + blas::zero(x_sloppy); // Check and see if we're "done" with the L2 norm. This could be because // we were already done with it, we never needed it, or the L2 norm has finally converged. 
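+ // (convergenceL2(r2, stop) now tests the L2 criterion over the full vector of right-hand sides)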
- if (!L2breakdown && convergenceL2(r2, hq_res, stop, param.tol_hq)) L2breakdown = true;
+ if (!L2breakdown && convergenceL2(r2, stop)) L2breakdown = true;
 // Depending on if we're still grinding on the L2 norm or if we've moved along to just
 // the HQ norm, we reset the baselines for reliable updates that get used on the
@@ -792,8 +701,8 @@
 } else {
 // If we've made it to the HQ norm, the new baseline is the freshly recomputed
 // heavy quark residual
- maxrr = hq_res;
- maxrx = hq_res;
+ maxrr = hq_res[0];
+ maxrx = hq_res[0];
 // Once we're dealing with the heavy quark residual, we perform a *hard* CG
 // restart at every reliable update via setting the search vector `p` to the current
@@ -828,7 +737,7 @@
 // ...tell the world about it too.
 warningQuda(
 "new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
- sqrt(r2), r0Norm, resIncreaseTotal);
+ sqrt(r2[0]), r0Norm, resIncreaseTotal);
 // If the norm is ridiculously small in magnitude, we've exceeded the maximums on various
 // ways we keep track of residual increases, or the L2 norm converged, we say "we're good here"
@@ -840,8 +749,8 @@
 // We also have to do a logic correction, switching the reliable update baselines we set above
 // from the L2 norm over to the HQ residual.
- maxrr = hq_res;
- maxrx = hq_res;
+ maxrr = hq_res[0];
+ maxrx = hq_res[0];
 }
 } else {
 // This variable counts the number of times in a row the computed residual has gone up,
@@ -856,8 +765,8 @@
 hqresIncrease++;
 // Tell the world about it
- warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", hq_res,
- hq0Res);
+ warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e",
+ hq_res[0], hq0Res[0]);
 // And if it's increased too many times in a row, flunk out of the solve.
 if (hqresIncrease > param.max_hq_res_increase) {
@@ -876,17 +785,18 @@
 if (heavy_quark_restart) {
 // If we're in the HQ residual part of the solve, we just do a hard CG restart.
 logQuda(QUDA_DEBUG_VERBOSE, "HQ restart == hard CG restart\n");
- blas::copy(p, rSloppy);
+ blas::copy(p, r_sloppy);
 heavy_quark_restart = false;
 } else {
 // If we're still in the L2 norm part of the solve, we explicitly restore
 // the orthogonality of the gradient vector, recompute beta, update `p`, and carry on with our lives.
 logQuda(QUDA_DEBUG_VERBOSE, "Regular restart == explicit gradient vector re-orthogonalization\n");
- Complex rp = blas::cDotProduct(rSloppy, p) / (r2);
- blas::caxpy(-rp, rSloppy, p);
+ auto rp = blas::cDotProduct(r_sloppy, p);
+ for (auto i = 0u; i < b.size(); i++) rp[i] /= r2[i];
+ blas::caxpy(-rp, r_sloppy, p);
- beta = r2 / r2_old;
- blas::xpayz(rSloppy, beta, p, p);
+ for (auto i = 0u; i < b.size(); i++) beta[i] = r2[i] / r2_old[i];
+ blas::xpayz(r_sloppy, beta, p, p);
 }
 // Last, we increment the reliable update counter, reset the number of steps since the last reliable update,
@@ -894,7 +804,7 @@
 // reliable update.
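+ // (r0Norm and hq0Res are refreshed just below from the freshly recomputed residuals)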
 rUpdate++;
 steps_since_reliable = 0;
- r0Norm = sqrt(r2);
+ r0Norm = sqrt(r2[0]);
 hq0Res = hq_res;
 }
@@ -903,18 +813,18 @@
 PrintStats("CG", k, r2, b2, hq_res);
 // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently
- converged = convergence(r2, hq_res, stop, param.tol_hq);
+ converged = convergence(r2, hq_res, stop, stop_hq);
 // check for recent enough reliable updates of the HQ residual if we use it
 // L2 is converged or precision maxed out for L2
- bool L2done = L2breakdown || convergenceL2(r2, hq_res, stop, param.tol_hq);
+ bool L2done = L2breakdown || convergenceL2(r2, stop);
 // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
- bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(r2, hq_res, stop, param.tol_hq);
+ bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(hq_res, stop_hq);
 converged = L2done && HQdone;
 }
- blas::copy(x, xSloppy);
+ blas::copy(x, x_sloppy);
 blas::xpy(y, x);
 getProfile().TPSTOP(QUDA_PROFILE_COMPUTE);
@@ -929,11 +839,15 @@
 if (param.compute_true_res) {
 // compute the true residuals
 mat(r, x);
- param.true_res = sqrt(blas::xmyNorm(b, r) / b2);
- param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z);
+ auto true_r2 = blas::xmyNorm(b, r);
+ auto hq = blas::HeavyQuarkResidualNorm(x, r);
+ for (auto i = 0u; i < b.size(); i++) {
+ param.true_res[i] = sqrt(true_r2[i] / b2[i]);
+ param.true_res_hq[i] = sqrt(hq[i].z);
+ }
 }
- PrintSummary("CG", k, r2, b2, stop, param.tol_hq);
+ PrintSummary("CG", k, r2, b2, stop, stop_hq);
 getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE);
 }
@@ -1087,7 +1001,7 @@
 int k = 0;
- PrintStats("CG", k, r2avg / param.num_src, b2avg, 0.);
+ PrintStats("CG", k, r2avg / param.num_src, b2avg);
 bool allconverged = true;
 bool converged[QUDA_MAX_MULTI_SHIFT];
 for (int i = 0; i < param.num_src; i++) {
@@ -1217,7 +1131,7 @@
 }
 k++;
- PrintStats("CG", k, r2avg / param.num_src, b2avg, 0);
+ PrintStats("CG", k, r2avg / param.num_src, b2avg);
 // check convergence
 allconverged = true;
 for (int i = 0; i < param.num_src; i++) {
@@ -1245,7 +1159,7 @@
 param.true_res_offset[i] = param.true_res;
 param.true_res_hq_offset[i] = param.true_res_hq;
- PrintSummary("CG", k, r2(i, i).real(), b2[i], stop[i], 0.0);
+ PrintSummary("CG", k, r2(i, i).real(), b2[i], stop[i]);
 }
 getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE);
@@ -1803,7 +1717,7 @@ void CG::solve(ColorSpinorField& x, ColorSpinorField& b) {
 param.true_res_offset[i] = param.true_res;
 param.true_res_hq_offset[i] = param.true_res_hq;
- PrintSummary("CG", k, r2(i,i).real(), b2[i], stop[i], 0.0);
+ PrintSummary("CG", k, r2(i, i).real(), b2[i], stop[i]);
 }
 getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE);
diff --git a/lib/inv_cgne.cpp b/lib/inv_cgne.cpp
new file mode 100644
index 0000000000..12fe55ec95
--- /dev/null
+++ b/lib/inv_cgne.cpp
@@ -0,0 +1,103 @@
+#include "invert_quda.h"
+
+namespace quda
+{
+
+ CGNE::CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
+ const DiracMatrix &matEig, SolverParam &param) :
+ Solver(mat, matSloppy, matPrecon, matEig, param),
+ mmdag(mat.Expose()),
+ mmdagSloppy(matSloppy.Expose()),
+ mmdagPrecon(matPrecon.Expose()),
+ mmdagEig(matEig.Expose())
+ {
+ switch (param.inv_type) {
+ case QUDA_CGNE_INVERTER: cg = std::make_unique<CG>(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param); break;
+ case QUDA_CA_CGNE_INVERTER: cg = std::make_unique<CACG>(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param); break;
+ case QUDA_CG3NE_INVERTER: cg = std::make_unique<CG3>(mmdag, mmdagSloppy, mmdagPrecon, param); break;
+ default: errorQuda("Unexpected CG solver type %d", param.inv_type);
+ }
+ }
+
+ void CGNE::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b)
+ {
+ Solver::create(x, b);
+ if (!init || xe.size() != b.size()) {
+ ColorSpinorParam csParam(x[0]);
+ csParam.create = QUDA_NULL_FIELD_CREATE;
+ resize(xe, b.size(), csParam);
+ csParam.create = QUDA_ZERO_FIELD_CREATE;
+ resize(ye, b.size(), csParam);
+ init = true;
+ }
+ }
+
+ cvector_ref<const ColorSpinorField> CGNE::get_residual()
+ {
+ if (!init) errorQuda("No residual vector present");
+ if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled");
+ // CG residual will match the CGNE residual (FIXME: but only with zero initial guess?)
+ return param.use_init_guess ? xe : cg->get_residual();
+ }
+
+ // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution.
+ void CGNE::operator()(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b)
+ {
+ if (param.maxiter == 0 || param.Nsteps == 0) {
+ if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x);
+ return;
+ }
+
+ create(x, b);
+
+ const int iter0 = param.iter;
+ auto b2 = param.compute_true_res ? blas::norm2(b) : vector<double>(b.size(), 0.0);
+
+ if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) {
+ // compute initial residual
+ mmdag.Expose()->M(xe, x);
+
+ if (param.compute_true_res) {
+ bool is_zero = true;
+ for (auto i = 0u; i < b2.size(); i++) {
+ is_zero = is_zero && b2[i] == 0.0;
+ if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported");
+ }
+ auto r2 = blas::xmyNorm(b, xe); // xe = b - A x_0
+ if (is_zero) b2 = r2; // for an all-zero source, normalize against the initial residual instead
+ } else {
+ blas::xpay(b, -1.0, xe);
+ }
+
+ // compute solution to residual equation
+ cg->operator()(ye, xe);
+
+ mmdag.Expose()->Mdag(xe, ye);
+
+ // compute full solution
+ blas::xpy(xe, x);
+ } else {
+ cg->operator()(ye, b);
+ mmdag.Expose()->Mdag(x, ye);
+ }
+
+ if (param.compute_true_res || (param.use_init_guess && param.return_residual)) {
+ // compute the true residual
+ mmdag.Expose()->M(xe, x);
+ blas::xpay(b, -1.0, xe); // xe now holds the residual
+
+ vector<double> r2(b2.size());
+ if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) {
+ auto hq = blas::HeavyQuarkResidualNorm(x, xe);
+ for (auto i = 0u; i < b.size(); i++) {
+ param.true_res_hq[i] = sqrt(hq[i].z);
+ r2[i] = hq[i].y;
+ }
+ } else {
+ r2 = blas::norm2(xe);
+ }
+ for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]);
+ PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type));
+ }
+ }
+
+} // namespace quda
diff --git a/lib/inv_cgnr.cpp b/lib/inv_cgnr.cpp
new file mode 100644
index 0000000000..e143ad1187
--- /dev/null
+++ b/lib/inv_cgnr.cpp
@@ -0,0 +1,90 @@
+#include "invert_quda.h"
+
+namespace quda
+{
+
+ CGNR::CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
+ const DiracMatrix &matEig, SolverParam &param) :
+ Solver(mat, matSloppy, matPrecon, matEig, param),
+ mdagm(mat.Expose()),
+ mdagmSloppy(matSloppy.Expose()),
+ mdagmPrecon(matPrecon.Expose()),
+ mdagmEig(matEig.Expose())
+ {
+ switch (param.inv_type) {
+ case QUDA_CGNR_INVERTER: cg = std::make_unique<CG>(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param); break;
+ case QUDA_CA_CGNR_INVERTER: cg = std::make_unique<CACG>(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param); break;
+ case QUDA_CG3NR_INVERTER: cg = std::make_unique<CG3>(mdagm, mdagmSloppy, mdagmPrecon, param); break;
+ default: errorQuda("Unexpected CG solver type %d", param.inv_type);
+ }
+ }
+
+ void CGNR::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b)
+ {
+ Solver::create(x, b);
+ if (!init || br.size() != b.size()) {
+ ColorSpinorParam csParam(b[0]);
+ csParam.create = QUDA_ZERO_FIELD_CREATE;
+ resize(br, b.size(), csParam);
+ init = true;
+ }
+ }
+
+ cvector_ref<const ColorSpinorField> CGNR::get_residual()
+ {
+ if (!init) errorQuda("No residual vector present");
+ if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled");
+ return br;
+ }
+
+ // CGNR: Mdag M x = Mdag b is solved.
+ void CGNR::operator()(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b)
+ {
+ if (param.maxiter == 0 || param.Nsteps == 0) {
+ if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x);
+ return;
+ }
+
+ create(x, b);
+
+ const int iter0 = param.iter;
+ vector<double> b2(b.size(), 0.0);
+ if (param.compute_true_res) {
+ b2 = blas::norm2(b);
+ bool is_zero = true;
+ for (auto i = 0u; i < b2.size(); i++) {
+ is_zero = is_zero && b2[i] == 0.0;
+ if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported");
+ }
+ if (is_zero) { // compute initial residual vector
+ mdagm.Expose()->M(br, x);
+ b2 = blas::norm2(br);
+ }
+ }
+
+ mdagm.Expose()->Mdag(br, b);
+ cg->operator()(x, br);
+
+ if (param.compute_true_res || param.return_residual) {
+ // compute the true residual
+ mdagm.Expose()->M(br, x);
+ blas::xpay(b, -1.0, br); // br now holds the residual
+
+ if (param.compute_true_res) {
+ vector<double> r2(b.size());
+ if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) {
+ auto hq = blas::HeavyQuarkResidualNorm(x, br);
+ for (auto i = 0u; i < b.size(); i++) {
+ param.true_res_hq[i] = sqrt(hq[i].z);
+ r2[i] = hq[i].y;
+ }
+ } else {
+ r2 = blas::norm2(br);
+ }
+ for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]);
+ PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type));
+ }
+ }
+ }
+
+} // namespace quda
diff --git a/lib/inv_eigcg_quda.cpp b/lib/inv_eigcg_quda.cpp
index ad0747ace4..9961f8ef8e 100644
--- a/lib/inv_eigcg_quda.cpp
+++ b/lib/inv_eigcg_quda.cpp
@@ -335,7 +335,8 @@
 /*
 * This is a solo precision solver.
 */
- int IncEigCG::eigCGsolve(ColorSpinorField &x, ColorSpinorField &b) {
+ int IncEigCG::eigCGsolve(ColorSpinorField &x, const ColorSpinorField &b)
+ {
 int k=0;
@@ -358,7 +359,7 @@
 eigcg_args = new EigCGArgs(param.m, param.n_ev); // need only deflation meta structure
 csParam.create = QUDA_COPY_FIELD_CREATE;
- csParam.field = &b;
+ csParam.field = &const_cast<ColorSpinorField &>(b);
 rp = ColorSpinorField::Create(csParam);
 csParam.create = QUDA_ZERO_FIELD_CREATE;
 yp = ColorSpinorField::Create(csParam);
@@ -525,7 +526,8 @@
 return k;
 }
- int IncEigCG::initCGsolve(ColorSpinorField &x, ColorSpinorField &b) {
+ int IncEigCG::initCGsolve(ColorSpinorField &x, const ColorSpinorField &b)
+ {
 int k = 0;
 //Start init CG iterations:
 deflated_solver *defl_p = static_cast<deflated_solver *>(param.deflation_op);
@@ -586,7 +588,7 @@
 return k;
 }
- void IncEigCG::operator()(ColorSpinorField &out, ColorSpinorField &in)
+ void IncEigCG::operator()(ColorSpinorField &out, const ColorSpinorField &in)
 {
 if(param.rhs_idx == 0) max_eigcg_cycles = param.eigcg_max_restarts;
diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp
index 45db2e5536..8e9c1df32e 100644
--- a/lib/inv_gcr_quda.cpp
+++ b/lib/inv_gcr_quda.cpp
@@ -21,7 +21,7 @@
 return ds + 0.000001*dus;
 }
- void GCR::computeBeta(std::vector<Complex> &beta, std::vector<ColorSpinorField> &Ap, int i, int N, int k)
+ void GCR::computeBeta(std::vector<Complex> &beta, cvector_ref<ColorSpinorField> &Ap, int i, int N, int k)
 {
 std::vector<Complex> Beta(N, 0.0);
 blas::block::cDotProduct(Beta, {Ap.begin() + i, Ap.begin() + i + N}, Ap[k]); // vectorized dot product
@@ -36,14 +36,14 @@
 for (int j = 0; j < N; j++) beta[(i + j) * n_krylov + k] = Beta[j];
 }
- void GCR::updateAp(std::vector<Complex> &beta, std::vector<ColorSpinorField> &Ap, int begin, int size, int k)
+ void GCR::updateAp(std::vector<Complex> &beta, cvector_ref<ColorSpinorField> &Ap, int begin, int size, int k)
 {
 std::vector<Complex> beta_(size);
 for (int i = 0; i < size; i++) beta_[i] = -beta[(i + begin) * n_krylov + k];
 blas::block::caxpy(beta_, {Ap.begin() + begin, Ap.begin() + begin + size}, Ap[k]);
 }
- void GCR::orthoDir(std::vector<Complex> &beta, std::vector<ColorSpinorField> &Ap, int k, int pipeline)
+ void GCR::orthoDir(std::vector<Complex> &beta, cvector_ref<ColorSpinorField> &Ap, int k, int pipeline)
 {
 switch (pipeline) {
 case 0: // no kernel fusion
@@ -64,9 +64,9 @@
 {
 const int N = pipeline;
 for (int i=0; i0; r--) {
@@ -93,7 +93,7 @@
 }
 void GCR::updateSolution(ColorSpinorField &x, const std::vector<Complex> &alpha, const std::vector<Complex> &beta,
- std::vector<double> &gamma, int k, std::vector<ColorSpinorField> &p)
+ std::vector<double> &gamma, int k, cvector_ref<ColorSpinorField> &p)
 {
 std::vector<Complex> delta(k);
@@ -109,10 +109,7 @@
 matMdagM(DiracMdagM(matEig.Expose())),
 K(0),
 Kparam(param),
- n_krylov(param.Nkrylov),
- alpha(n_krylov),
- beta(n_krylov * n_krylov),
- gamma(n_krylov)
+ n_krylov(param.Nkrylov)
 {
 fillInnerSolverParam(Kparam, param);
@@ -128,10 +125,7 @@
 matMdagM(matEig.Expose()),
 K(nullptr),
 Kparam(param),
- n_krylov(param.Nkrylov),
- alpha(n_krylov),
- beta(n_krylov * n_krylov),
- gamma(n_krylov)
+ n_krylov(param.Nkrylov)
 {
 fillInnerSolverParam(Kparam, param);
 K = wrapExternalPreconditioner(K_);
@@ -142,40 +136,63 @@
 destroyDeflationSpace();
 }
- void GCR::create(ColorSpinorField &x, const ColorSpinorField &b)
+ void GCR::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b)
 {
 Solver::create(x, b);
- if (!init) {
+ if (!init || r.size() != b.size()) {
 getProfile().TPSTART(QUDA_PROFILE_INIT);
- ColorSpinorParam csParam(x);
+ ColorSpinorParam csParam(x[0]);
 csParam.create = QUDA_NULL_FIELD_CREATE;
 // create sloppy
fields used for orthogonalization csParam.setPrecision(param.precision_sloppy); - resize(p, n_krylov + 1, QUDA_NULL_FIELD_CREATE, csParam); - resize(Ap, n_krylov, QUDA_NULL_FIELD_CREATE, csParam); + p.resize(n_krylov + 1); + Ap.resize(n_krylov); + for (auto &p_ : p) resize(p_, b.size(), QUDA_NULL_FIELD_CREATE, csParam); + for (auto &ap : Ap) resize(ap, b.size(), QUDA_NULL_FIELD_CREATE, csParam); csParam.setPrecision(param.precision); if (K || mixed()) { - r = ColorSpinorField(csParam); + resize(r, b.size(), csParam); } else { - r = p[0].create_alias(); + create_alias(r, p[0]); } csParam.setPrecision(param.precision_sloppy); if (!K) { - r_sloppy = p[0].create_alias(); + create_alias(r_sloppy, p[0]); + } else if (!mixed()) { + create_alias(r_sloppy, r); } else { - r_sloppy = mixed() ? ColorSpinorField(csParam) : r.create_alias(); + resize(r_sloppy, b.size(), csParam); } getProfile().TPSTOP(QUDA_PROFILE_INIT); + + alpha.resize(b.size()); + beta.resize(b.size()); + gamma.resize(b.size()); + for (auto i = 0u; i < b.size(); i++) { + alpha[i].resize(n_krylov); + beta[i].resize(n_krylov * n_krylov); + gamma[i].resize(n_krylov); + } + init = true; } } - void GCR::operator()(ColorSpinorField &x, ColorSpinorField &b) + cvector_ref GCR::get_residual() + { + if (!init) errorQuda("No residual vector present"); + if (param.compute_true_res) + return r; + else + return K ? r_sloppy : p[k_break]; + } + + void GCR::operator()(cvector_ref &x, cvector_ref &b) { if (n_krylov == 0) { // Krylov space is zero-dimensional so return doing no work @@ -189,15 +206,14 @@ namespace quda { if (param.deflate) { // Construct the eigensolver and deflation space if requested. if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { // compute the deflation space. 
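+ // (the deflation space is built from the 0th RHS and then applied to the whole set)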
- getProfile().TPSTOP(QUDA_PROFILE_INIT);
 (*eig_solve)(evecs, evals);
 if (param.deflate) {
 // double the size of the Krylov space
@@ -205,7 +221,6 @@
 // populate extra memory with L/R singular vectors
 eig_solve->computeSVD(evecs, evals);
 }
- getProfile().TPSTART(QUDA_PROFILE_INIT);
 deflate_compute = false;
 }
 if (recompute_evals) {
@@ -215,8 +230,8 @@
 }
 }
- double b2 = blas::norm2(b); // norm sq of source
- double r2; // norm sq of residual
+ vector<double> b2 = blas::norm2(b); // norm sq of source
+ vector<double> r2; // norm sq of residual
 // compute initial residual depending on whether we have an initial guess or not
 if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) {
@@ -240,20 +255,13 @@
 }
 // Check to see that we're not trying to invert on a zero-field source
- if (b2 == 0) {
- if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) {
- getProfile().TPSTOP(QUDA_PROFILE_INIT);
- warningQuda("inverting on zero-field source\n");
- x = b;
- param.true_res = 0.0;
- param.true_res_hq = 0.0;
- return;
- } else {
- b2 = r2;
- }
+ if (is_zero_src(x, b, b2)) {
+ getProfile().TPSTOP(QUDA_PROFILE_INIT);
+ return;
 }
- double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
+ auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
+ auto stop_hq = vector<double>(b.size(), param.tol_hq);
 const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
@@ -264,8 +272,11 @@
 const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
 const int maxResIncreaseTotal = param.max_res_increase_total;
- double heavy_quark_res = 0.0; // heavy quark residual
- if(use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x,r).z);
+ std::vector<double> heavy_quark_res(b.size()); // heavy quark residual
+ if (use_heavy_quark_res) {
+ auto hq = blas::HeavyQuarkResidualNorm(x, r);
+ for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z);
+ }
 int resIncrease = 0;
 int resIncreaseTotal = 0;
@@ -277,23 +288,21 @@
 int total_iter = 0;
 int restart = 0;
- double r2_old = r2;
- double maxr_deflate = sqrt(r2);
+ auto r2_old = r2;
+ double maxr_deflate = sqrt(r2[0]);
 bool l2_converge = false;
 int pipeline = param.pipeline;
- // Vectorized dot product only has limited support so work around
- if (Ap[0].Location() == QUDA_CPU_FIELD_LOCATION || pipeline == 0) pipeline = 1;
 if (pipeline > n_krylov) pipeline = n_krylov;
 getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE);
 getProfile().TPSTART(QUDA_PROFILE_COMPUTE);
 int k = 0;
- int k_break = 0;
+ k_break = 0;
 PrintStats("GCR", total_iter+k, r2, b2, heavy_quark_res);
- while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) {
+ while (!convergence(r2, heavy_quark_res, stop, stop_hq) && total_iter < param.maxiter) {
 if (K) {
 pushVerbosity(param.verbosity_precondition);
@@ -305,16 +314,29 @@
 matSloppy(Ap[k], p[k]);
- orthoDir(beta, Ap, k, pipeline);
+ auto get_i = [](std::vector<std::vector<ColorSpinorField>> &p, int i) {
+ vector_ref<ColorSpinorField> p_i;
+ p_i.reserve(p.size());
+ for (auto &pi : p) p_i.push_back(pi[i]);
+ return p_i;
+ };
- double3 Apr = blas::cDotProductNormA(Ap[k], K ?
r_sloppy : p[k]); + for (auto i = 0u; i < b.size(); i++) orthoDir(beta[i], get_i(Ap, i), k, pipeline); - gamma[k] = sqrt(Apr.z); // gamma[k] = Ap[k] - if (gamma[k] == 0.0) errorQuda("GCR breakdown"); - alpha[k] = Complex(Apr.x, Apr.y) / gamma[k]; // alpha = (1/|Ap|) * (Ap, r) + auto Apr = blas::cDotProductNormA(Ap[k], K ? r_sloppy : p[k]); + + for (auto i = 0u; i < b.size(); i++) { + gamma[i][k] = sqrt(Apr[i].z); // gamma[k] = Ap[k] + if (gamma[i][k] == 0.0) errorQuda("GCR breakdown"); + alpha[i][k] = Complex(Apr[i].x, Apr[i].y) / gamma[i][k]; // alpha = (1/|Ap|) * (Ap, r) + } // r -= (1/|Ap|^2) * (Ap, r) r, Ap *= 1/|Ap| - r2 = blas::cabxpyzAxNorm(1.0 / gamma[k], -alpha[k], Ap[k], K ? r_sloppy : p[k], K ? r_sloppy : p[k + 1]); + std::vector gamma_k_inv(b.size()); + for (auto i = 0u; i < gamma_k_inv.size(); i++) gamma_k_inv[i] = 1.0 / gamma[i][k]; + std::vector alpha_k(b.size()); + for (auto i = 0u; i < alpha_k.size(); i++) alpha_k[i] = -alpha[i][k]; + r2 = blas::cabxpyzAxNorm(gamma_k_inv, alpha_k, Ap[k], K ? r_sloppy : p[k], K ? r_sloppy : p[k + 1]); k++; total_iter++; @@ -323,16 +345,17 @@ namespace quda { // update since n_krylov or maxiter reached, converged or reliable update required // note that the heavy quark residual will by definition only be checked every n_krylov steps - if (k == n_krylov || total_iter == param.maxiter || (r2 < stop && !l2_converge) || sqrt(r2 / r2_old) < param.delta) { + if (k == n_krylov || total_iter == param.maxiter || (r2[0] < stop[0] && !l2_converge) + || sqrt(r2[0] / r2_old[0]) < param.delta) { // update the solution vector - updateSolution(x, alpha, beta, gamma, k, p); + for (auto i = 0u; i < b.size(); i++) updateSolution(x[i], alpha[i], beta[i], gamma[i], k, get_i(p, i)); if ( (r2 < stop || total_iter==param.maxiter) && param.sloppy_converge) break; mat(r, x); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < maxr_deflate * param.tol_restart) { // Deflate: Hardcoded to SVD. 
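+ // (GCR deflates with the left/right singular vectors computed above, hence deflateSVD rather than deflate)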
eig_solve->deflateSVD(x, r, evecs, evals, true); @@ -340,17 +363,21 @@ namespace quda { mat(r, x); r2 = blas::xmyNorm(b, r); - maxr_deflate = sqrt(r2); + maxr_deflate = sqrt(r2[0]); } - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } // break-out check if we have reached the limit of the precision if (r2 > r2_old) { resIncrease++; resIncreaseTotal++; - warningQuda("GCR: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), sqrt(r2_old), resIncreaseTotal); + warningQuda( + "GCR: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", + sqrt(r2[0]), sqrt(r2_old[0]), resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("GCR: solver exiting due to too many true residual norm increases"); break; @@ -362,7 +389,7 @@ namespace quda { k_break = k; k = 0; - if ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) ) { + if (!convergence(r2, heavy_quark_res, stop, stop_hq)) { restart++; // restarting if residual is still too great PrintStats("GCR (restart)", restart, r2, b2, heavy_quark_res); @@ -371,7 +398,7 @@ namespace quda { r2_old = r2; // prevent ending the Krylov space prematurely if other convergence criteria not met - if (r2 < stop) l2_converge = true; + if (r2 < stop) l2_converge = true; } r2_old = r2; @@ -388,23 +415,19 @@ namespace quda { if (param.compute_true_res) { // Calculate the true residual mat(r, x); - double true_res = blas::xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) - param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x,r).z); - else - param.true_res_hq = 0.0; - //if (param.preserve_source == QUDA_PRESERVE_SOURCE_NO) blas::copy(b, r); - } else { - // reuse this when we add the get_residual method to GCR - if (0) blas::copy(b, K ? 
r_sloppy : p[k_break]); + auto true_r2 = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } param.iter += total_iter; getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); - PrintSummary("GCR", total_iter, r2, b2, stop, param.tol_hq); + PrintSummary("GCR", total_iter, r2, b2, stop, stop_hq); } } // namespace quda diff --git a/lib/inv_gmresdr_quda.cpp b/lib/inv_gmresdr_quda.cpp index 8d2b323146..6853203dab 100644 --- a/lib/inv_gmresdr_quda.cpp +++ b/lib/inv_gmresdr_quda.cpp @@ -379,7 +379,7 @@ namespace quda { return (j - start_idx); } - void GMResDR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void GMResDR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { getProfile().TPSTART(QUDA_PROFILE_INIT); diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index df8f0f2dbc..1696cd4d42 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -20,32 +20,31 @@ namespace quda } } - void MR::create(ColorSpinorField &x, const ColorSpinorField &b) + void MR::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_NULL_FIELD_CREATE; - - r = ColorSpinorField(csParam); + if (!init || r.size() != b.size()) { + resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); // now allocate sloppy fields + ColorSpinorParam csParam(b[0]); + csParam.create = QUDA_NULL_FIELD_CREATE; csParam.setPrecision(param.precision_sloppy); - Ar = ColorSpinorField(csParam); - x_sloppy = ColorSpinorField(csParam); + resize(Ar, b.size(), csParam); + resize(x_sloppy, b.size(), csParam); - bool mixed = param.precision != param.precision_sloppy; - - if (!mixed) csParam.create = QUDA_REFERENCE_FIELD_CREATE; - csParam.v = r.data(); - r_sloppy = ColorSpinorField(csParam); + if (param.precision != param.precision_sloppy) { // mixed precision + resize(r_sloppy, b.size(), csParam); + } else { + create_alias(r_sloppy, r); + } init = true; } // init } - ColorSpinorField &MR::get_residual() + cvector_ref MR::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -53,7 +52,7 @@ namespace quda return r; } - void MR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void MR::operator()(cvector_ref &x, cvector_ref &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -64,11 +63,14 @@ namespace quda if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_COMPUTE); - double b2 = blas::norm2(b); // Save norm of b - double r2 = 0.0; // if zero source then we will exit immediately doing no work + vector b2 = blas::norm2(b); // Save norm of b + vector r2; + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); r2 = blas::xmyNorm(b, r); // r = b - Ax0 + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; } else { r2 = b2; blas::copy(r, b); @@ -76,18 +78,20 @@ namespace quda } blas::copy(r_sloppy, r); - // if invalid residual then convergence is set by iteration count only - double stop = param.residual_type == QUDA_INVALID_RESIDUAL ? 
0.0 : b2 * param.tol * param.tol;
+ auto stop = stopping(param.tol, b2, param.residual_type);
 int iter = 0;
 int step = 0;
 bool converged = false;
- PrintStats("MR", iter, r2, b2, 0.0);
+ PrintStats("MR", iter, r2, b2);
 while (!converged) {
 int k = 0;
- double scale = 1.0;
+ vector<double> scale(b.size(), 1.0);
+ vector<double> scale_inv(b.size(), 1.0);
+ vector<double> delta2(b.size(), param.delta * param.delta);
+
 if ((node_parity + step) % 2 == 0 && param.schwarz_type == QUDA_MULTIPLICATIVE_SCHWARZ) {
 // for multiplicative Schwarz we alternate updates depending on node parity
 } else {
@@ -95,31 +99,34 @@
 commGlobalReductionPush(param.global_reduction); // use local reductions for DD solver
 blas::zero(x_sloppy); // can get rid of this for a special first update kernel
- double c2 = param.global_reduction == QUDA_BOOLEAN_TRUE ? r2 : blas::norm2(r); // c2 holds the initial r2
- scale = c2 > 0.0 ? sqrt(c2) : 1.0;
-
- // domain-wise normalization of the initial residual to prevent underflow
- if (c2 > 0.0) {
- blas::ax(1 / scale, r_sloppy); // can merge this with the prior copy
- r2 = 1.0; // by definition by this is now true
+ auto c2 = param.global_reduction == QUDA_BOOLEAN_TRUE ? r2 : blas::norm2(r); // c2 holds the initial r2
+ for (auto i = 0u; i < b.size(); i++) {
+ scale[i] = c2[i] > 0.0 ? sqrt(c2[i]) : 1.0;
+ scale_inv[i] = 1.0 / scale[i];
+ // domain-wise normalization of the initial residual to prevent underflow
+ if (c2[i] > 0.0) r2[i] = 1.0; // by definition this is now true
 }
+ blas::ax(scale_inv, r_sloppy); // can merge this with the prior copy
- while (k < param.maxiter && r2 > param.delta * param.delta) {
+ while (k < param.maxiter && r2 > delta2) {
 matSloppy(Ar, r_sloppy);
 if (param.global_reduction) {
- double4 Ar4 = blas::cDotProductNormAB(Ar, r_sloppy);
- Complex alpha = Complex(Ar4.x, Ar4.y) / Ar4.z;
- r2 = Ar4.w;
- PrintStats("MR (inner)", iter, r2, b2, 0.0);
+ auto Ar4 = blas::cDotProductNormAB(Ar, r_sloppy);
+ vector<Complex> alpha(b.size());
+ for (auto i = 0u; i < b.size(); i++) {
+ alpha[i] = Complex(Ar4[i].x, Ar4[i].y) / Ar4[i].z;
+ r2[i] = Ar4[i].w;
+ }
+ PrintStats("MR (inner)", iter, r2, b2);
 // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = blas::norm2(r)
 blas::caxpyXmaz(param.omega * alpha, r_sloppy, x_sloppy, Ar);
 } else {
 // doing local reductions so can make it asynchronous
 commAsyncReductionSet(true);
- blas::cDotProductNormA(Ar, r_sloppy);
+ blas::cDotProductNormAB(Ar, r_sloppy);
 // omega*alpha is done in the kernel
 blas::caxpyXmazMR(param.omega, r_sloppy, x_sloppy, Ar);
@@ -140,10 +147,10 @@
 if (compute_true_res) {
 mat(r, x);
 r2 = blas::xmyNorm(b, r);
- param.true_res = sqrt(r2 / b2);
+ for (auto i = 0u; i < b2.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]);
 converged = (step < param.Nsteps && r2 > stop) ? false : true;
 if (!converged) blas::copy(r_sloppy, r);
- PrintStats("MR (restart)", iter, r2, b2, 0.0);
+ PrintStats("MR (restart)", iter, r2, b2);
 } else {
 blas::ax(scale, r_sloppy);
 r2 = blas::norm2(r_sloppy);
@@ -153,7 +160,7 @@
 step++;
 }
- PrintSummary("MR", iter, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq);
+ PrintSummary("MR", iter, r2, b2, stopping(param.tol, b2, param.residual_type));
 if (!param.is_preconditioner) {
 getProfile().TPSTOP(QUDA_PROFILE_COMPUTE);
diff --git a/lib/inv_mre.cpp b/lib/inv_mre.cpp
index 2ef194ba4e..f8d078a677 100644
--- a/lib/inv_mre.cpp
+++ b/lib/inv_mre.cpp
@@ -106,8 +106,7 @@
 }
 // if operator hasn't already been applied then apply
- if (apply_mat)
- for (int i = 0; i < N; i++) mat(q[i], p[i]);
+ if (apply_mat) mat(q, p);
 // Solution coefficient vectors
 std::vector<Complex> alpha(N);
@@ -128,7 +127,7 @@
 // compute the residual only if we're going to print it
 ColorSpinorField r(b);
 for (auto &a : alpha) a = -a;
- blas::caxpy(alpha, q, r);
+ blas::block::caxpy(alpha, q, r);
 printfQuda("MinResExt: N = %d, |res| / |src| = %e\n", N, sqrt(blas::norm2(r) / blas::norm2(b)));
 }
diff --git a/lib/inv_pcg_quda.cpp b/lib/inv_pcg_quda.cpp
index 4fb03527ba..57244a04fb 100644
--- a/lib/inv_pcg_quda.cpp
+++ b/lib/inv_pcg_quda.cpp
@@ -18,8 +18,8 @@
 using namespace blas;
- PreconCG::PreconCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
- const DiracMatrix &matEig, SolverParam &param) :
+ PCG::PCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
+ const DiracMatrix &matEig, SolverParam &param) :
 Solver(mat, matSloppy, matPrecon, matEig, param), K(nullptr), Kparam(param)
 {
 fillInnerSolverParam(Kparam, param);
@@ -30,8 +30,8 @@
 K = createPreconditioner(matPrecon, matPrecon, matPrecon, matEig, param, Kparam);
 }
- PreconCG::PreconCG(const DiracMatrix &mat, Solver &K_, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
- const DiracMatrix &matEig, SolverParam &param) :
+ PCG::PCG(const DiracMatrix &mat, Solver &K_, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
+ const DiracMatrix &matEig, SolverParam &param) :
 Solver(mat, matSloppy, matPrecon, matEig, param), K(nullptr), Kparam(param)
 {
 fillInnerSolverParam(Kparam, param);
@@ -39,7 +39,7 @@
 K = wrapExternalPreconditioner(K_);
 }
- PreconCG::~PreconCG()
+ PCG::~PCG()
 {
 getProfile().TPSTART(QUDA_PROFILE_FREE);
@@ -49,85 +49,80 @@
 getProfile().TPSTOP(QUDA_PROFILE_FREE);
 }
- void PreconCG::create(ColorSpinorField &x, const ColorSpinorField &b)
+ void PCG::create(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b)
 {
 Solver::create(x, b);
- if (!init) {
+ if (!init || r.size() != b.size()) {
 getProfile().TPSTART(QUDA_PROFILE_INIT);
- ColorSpinorParam csParam(b);
+ ColorSpinorParam csParam(b[0]);
- r = ColorSpinorField(b);
- if (K) minvr = ColorSpinorField(b);
+ resize(r, b.size(), csParam);
 csParam.create = QUDA_ZERO_FIELD_CREATE;
- y = ColorSpinorField(csParam);
+ resize(y, b.size(), csParam);
 // create sloppy fields
 csParam.setPrecision(param.precision_sloppy);
 csParam.create = QUDA_NULL_FIELD_CREATE;
- Ap = ColorSpinorField(csParam);
+ resize(Ap, b.size(), csParam);
- x_sloppy = (!mixed() || !param.use_sloppy_partial_accumulator) ?
- x.create_alias() : ColorSpinorField(csParam); + if (!mixed() || !param.use_sloppy_partial_accumulator) { + create_alias(x_sloppy, x); + } else { + resize(x_sloppy, b.size(), csParam); + } - csParam.create = QUDA_COPY_FIELD_CREATE; - csParam.field = &r; - r_sloppy = !mixed() ? r.create_alias() : ColorSpinorField(csParam); + if (!mixed()) { + create_alias(r_sloppy, r); + } else { + resize(r_sloppy, b.size(), csParam); + } if (K) { - csParam.field = &minvr; - minvr_sloppy = !mixed() ? minvr.create_alias() : ColorSpinorField(csParam); + resize(minvr_sloppy, b.size(), csParam); // create preconditioner intermediates - csParam.create = QUDA_NULL_FIELD_CREATE; csParam.setPrecision(Kparam.precision); - r_pre = ColorSpinorField(csParam); + resize(r_pre, b.size(), csParam); // Create minvr_pre - minvr_pre = ColorSpinorField(csParam); + resize(minvr_pre, b.size(), csParam); } Np = (param.solution_accumulator_pipeline == 0 ? 1 : param.solution_accumulator_pipeline); if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline", Np); - csParam.create = QUDA_NULL_FIELD_CREATE; - csParam.setPrecision(param.precision_sloppy); - x_update_batch = XUpdateBatch(Np, K ? minvr_sloppy : r_sloppy, csParam); - getProfile().TPSTOP(QUDA_PROFILE_INIT); init = true; } } - void PreconCG::solve_and_collect(ColorSpinorField &x, ColorSpinorField &b, - cvector_ref &v_r, - int collect_miniter, double collect_tol) + void PCG::solve_and_collect(cvector_ref &x, cvector_ref &b, + cvector_ref &v_r, int collect_miniter, double collect_tol) { - if (K) K->train_param(*this, b); + if (K) K->train_param(*this, b[0]); - create(x, b); + if (v_r.size() && x.size() > 1) errorQuda("Collect not supported for multi-RHS PCG"); getProfile().TPSTART(QUDA_PROFILE_INIT); // whether to select alternative reliable updates bool alternative_reliable = param.use_alternative_reliable; - double b2 = blas::norm2(b); + auto b2 = blas::norm2(b); // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + if (is_zero_src(x, b, b2)) { getProfile().TPSTOP(QUDA_PROFILE_INIT); - warningQuda("Warning: inverting on zero-field source"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; return; } + create(x, b); + if (param.deflate) { // Construct the eigensolver and deflation space if requested. - constructDeflationSpace(b, matEig); + constructDeflationSpace(b[0], matEig); if (deflate_compute) { // compute the deflation space. (*eig_solve)(evecs, evals); @@ -139,22 +134,23 @@ namespace quda } } - double Anorm = 0; + double Anorm = 0.0; // for alternative reliable updates if (alternative_reliable) { // estimate norm for reliable updates - mat(r, b); - Anorm = sqrt(blas::norm2(r) / b2); + mat(r[0], b[0]); + Anorm = sqrt(norm2(r[0]) / b2[0]); } // compute initial residual - double r2 = 0.0; + vector r2(b.size(), 0.0); if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute r = b - A * x mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; // y contains the original guess. 
 blas::copy(y, x);
 } else {
@@ -171,34 +167,44 @@
 }
 blas::zero(x);
- if (&x != &x_sloppy) blas::zero(x_sloppy);
+ if (param.use_sloppy_partial_accumulator) blas::zero(x_sloppy);
+ if (r_sloppy[0].Precision() != r[0].Precision()) blas::copy(r_sloppy, r);
+
+ ColorSpinorParam csParam(r_sloppy[0]);
+ std::vector<XUpdateBatch> x_update_batch(b.size());
+ for (auto i = 0u; i < b.size(); i++)
+ x_update_batch[i] = XUpdateBatch(Np, K ? minvr_sloppy[i] : r_sloppy[i], csParam);
 const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false;
 if (K) {
- r_pre = r_sloppy;
+ blas::copy(r_pre, r_sloppy);
 pushVerbosity(param.verbosity_precondition);
 (*K)(minvr_pre, r_pre);
 popVerbosity();
- minvr_sloppy = minvr_pre;
+ blas::copy(minvr_sloppy, minvr_pre);
 }
 getProfile().TPSTOP(QUDA_PROFILE_INIT);
 getProfile().TPSTART(QUDA_PROFILE_PREAMBLE);
- double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
- double heavy_quark_res = 0.0; // heavy quark residual
- if (use_heavy_quark_res) heavy_quark_res = sqrt(HeavyQuarkResidualNorm(x, r).z);
+ auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver
+ auto stop_hq = std::vector<double>(b.size(), param.tol_hq);
- double beta = 0.0;
- double pAp;
- double rMinvr = 0;
- double rMinvr_old = 0.0;
- double r_new_Minvr_old = 0.0;
- double r2_old = 0;
- r2 = norm2(r);
+ std::vector<double> heavy_quark_res(b.size(), 0.0); // heavy quark residual
+ if (use_heavy_quark_res) {
+ auto hq = HeavyQuarkResidualNorm(x, r);
+ for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z);
+ }
+
+ std::vector<double> beta(b.size(), 0.0);
+ std::vector<double> pAp(b.size(), 0.0);
+ std::vector<double> rMinvr(b.size(), 0.0);
+ std::vector<double> rMinvr_old(b.size(), 0.0);
+ std::vector<double> r_new_Minvr_old(b.size(), 0.0);
+ std::vector<double> r2_old(b.size(), 0.0);
- if (K) rMinvr = reDotProduct(r_sloppy, minvr_sloppy);
+ if (K) { rMinvr = reDotProduct(r_sloppy, minvr_sloppy); }
 getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE);
 getProfile().TPSTART(QUDA_PROFILE_COMPUTE);
@@ -222,44 +228,62 @@
 ru_params.hqmaxresIncrease = param.max_hq_res_increase;
 ru_params.hqmaxresRestartTotal = param.max_hq_res_restart_total;
- ReliableUpdates ru(ru_params, r2);
+ ReliableUpdates ru(ru_params, r2[0]);
- bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);
+ bool converged = convergence(r2, heavy_quark_res, stop, stop_hq);
- while (!converged && k < param.maxiter) {
+ auto get_p = [](std::vector<XUpdateBatch> &x_update_batch, bool next = false) {
+ vector_ref<ColorSpinorField> p;
+ p.reserve(x_update_batch.size());
+ for (auto &x : x_update_batch) p.push_back(next ?
x.get_next_field() : x.get_current_field()); + return p; + }; - matSloppy(Ap, x_update_batch.get_current_field()); + auto get_alpha = [](std::vector &x_update_batch) { + vector alpha; + alpha.reserve(x_update_batch.size()); + for (auto &x : x_update_batch) alpha.push_back(x.get_current_alpha()); + return alpha; + }; + + while (!converged && k < param.maxiter) { + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); + matSloppy(Ap, p); - double sigma; // alternative reliable updates, if (alternative_reliable) { - double3 pAppp = blas::cDotProductNormA(x_update_batch.get_current_field(), Ap); - pAp = pAppp.x; - ru.update_ppnorm(pAppp.z); + auto pAppp = blas::cDotProductNormA(p, Ap); + for (auto i = 0u; i < b.size(); i++) pAp[i] = pAppp[i].x; + ru.update_ppnorm(pAppp[0].z); } else { - pAp = reDotProduct(x_update_batch.get_current_field(), Ap); + pAp = reDotProduct(p, Ap); } - x_update_batch.get_current_alpha() = (K) ? rMinvr / pAp : r2 / pAp; - double2 cg_norm = axpyCGNorm(-x_update_batch.get_current_alpha(), Ap, r_sloppy); + for (auto i = 0u; i < b.size(); i++) + x_update_batch[i].get_current_alpha() = K ? rMinvr[i] / pAp[i] : r2[i] / pAp[i]; + + auto cg_norm = axpyCGNorm(-get_alpha(x_update_batch), Ap, r_sloppy); // r --> r - alpha*A*p r2_old = r2; - r2 = cg_norm.x; - sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k-1 - r_k) breaks + vector sigma(b.size()); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = cg_norm[i].x; + sigma[i] = cg_norm[i].y >= 0.0 ? cg_norm[i].y : r2[i]; // use r2 if (r_k+1, r_k-1 - r_k) breaks + } if (K) rMinvr_old = rMinvr; - ru.update_rNorm(sqrt(r2)); - - ru.evaluate(r2_old); + ru.update_rNorm(sqrt(r2[0])); + ru.evaluate(r2_old[0]); // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) ru.set_updateX(); + if (convergence(r2, heavy_quark_res, stop, stop_hq) && param.delta >= param.tol) ru.set_updateX(); - if (collect > 0 && k > collect_miniter && r2 < collect_tol * collect_tol * b2) { - v_r[v_r.size() - collect] = r_sloppy; - logQuda(QUDA_VERBOSE, "Collecting r %2d: r2 / b2 = %12.8e, k = %5d\n", collect, sqrt(r2 / b2), k); + if (collect > 0 && k > collect_miniter && r2[0] < collect_tol * collect_tol * b2[0]) { + blas::copy(v_r[v_r.size() - collect], r_sloppy); + logQuda(QUDA_VERBOSE, "Collecting r %2d: r2 / b2 = %12.8e, k = %5d\n", collect, sqrt(r2[0] / b2[0]), k); collect--; } @@ -278,35 +302,36 @@ namespace quda minvr_sloppy = minvr_pre; rMinvr = reDotProduct(r_sloppy, minvr_sloppy); - beta = (rMinvr - r_new_Minvr_old) / rMinvr_old; + for (auto i = 0u; i < b.size(); i++) beta[i] = (rMinvr[i] - r_new_Minvr_old[i]) / rMinvr_old[i]; } else { - beta = sigma / r2_old; // use the alternative beta computation + for (auto i = 0u; i < b.size(); i++) beta[i] = sigma[i] / r2_old[i]; // use the alternative beta computation } if (Np == 1) { - axpyZpbx(x_update_batch.get_current_alpha(), x_update_batch.get_current_field(), x_sloppy, - K ? minvr_sloppy : r_sloppy, beta); + axpyZpbx(get_alpha(x_update_batch), p, x_sloppy, K ? minvr_sloppy : r_sloppy, beta); } else { - if (x_update_batch.is_container_full()) { x_update_batch.accumulate_x(x_sloppy); } - blas::xpayz(K ? 
minvr_sloppy : r_sloppy, beta, x_update_batch.get_current_field(), - x_update_batch.get_next_field()); + for (auto i = 0u; i < b.size(); i++) { + if (x_update_batch[i].is_container_full()) x_update_batch[i].accumulate_x(x_sloppy[i]); + } + blas::xpayz(K ? minvr_sloppy : r_sloppy, beta, p, p_next); } - ru.accumulate_norm(x_update_batch.get_current_alpha()); + ru.accumulate_norm(get_alpha(x_update_batch)[0]); } else { // reliable update // Now that we are performing reliable update, need to update x with the p's that have // not been used yet - x_update_batch.accumulate_x(x_sloppy); - x_update_batch.reset_next(); - + for (auto i = 0u; i < b.size(); i++) { + x_update_batch[i].accumulate_x(x_sloppy[i]); + x_update_batch[i].reset_next(); + } xpy(x_sloppy, y); // y += x // Now compute r mat(r, y); r2 = xmyNorm(b, r); - if (param.deflate && sqrt(r2) < ru.maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < ru.maxr_deflate * param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); @@ -314,7 +339,7 @@ namespace quda mat(r, y); r2 = blas::xmyNorm(b, r); - ru.update_maxr_deflate(r2); + ru.update_maxr_deflate(r2[0]); } copy(r_sloppy, r); @@ -322,11 +347,13 @@ namespace quda bool L2breakdown = false; double L2breakdown_eps = 0; - if (ru.reliable_break(r2, stop, L2breakdown, L2breakdown_eps)) { break; } + if (ru.reliable_break(r2[0], stop[0], L2breakdown, L2breakdown_eps)) { break; } - ru.update_norm(r2, y); + ru.update_norm(r2[0], y[0]); + ru.reset(r2[0]); - ru.reset(r2); + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); if (K) { // can fuse these two kernels @@ -341,31 +368,35 @@ namespace quda minvr_sloppy = minvr_pre; rMinvr = reDotProduct(r_sloppy, minvr_sloppy); - beta = (rMinvr - r_new_Minvr_old) / rMinvr_old; + for (auto i = 0u; i < b.size(); i++) beta[i] = (rMinvr[i] - r_new_Minvr_old[i]) / rMinvr_old[i]; } else { // standard CG - no preconditioning // explicitly restore the orthogonality of the gradient vector - double rp = reDotProduct(r_sloppy, x_update_batch.get_current_field()) / (r2); - axpy(-rp, r_sloppy, x_update_batch.get_current_field()); + auto rp = cDotProduct(r_sloppy, p); + for (auto i = 0u; i < b.size(); i++) rp[i] /= r2[i]; + caxpy(-rp, r_sloppy, p); - beta = r2 / r2_old; + for (auto i = 0u; i < b.size(); i++) beta[i] = r2[i] / r2_old[i]; } - xpayz(K ? minvr_sloppy : r_sloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); + xpayz(K ? 
minvr_sloppy : r_sloppy, beta, p, p_next); } - ++k; + k++; PrintStats("PCG", k, r2, b2, heavy_quark_res); - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + converged = convergence(r2, heavy_quark_res, stop, stop_hq); + // if we have converged and need to update any trailing solutions - if ((converged || k == param.maxiter) && ru.steps_since_reliable > 0 && !x_update_batch.is_container_full()) { - x_update_batch.accumulate_x(x_sloppy); - } + for (auto i = 0u; i < b.size(); i++) { + if ((converged || k == param.maxiter) && ru.steps_since_reliable > 0 && !x_update_batch[i].is_container_full()) { + x_update_batch[i].accumulate_x(x_sloppy[i]); + } - if (ru.steps_since_reliable == 0) { - x_update_batch.reset(); - } else { - ++x_update_batch; + if (ru.steps_since_reliable == 0) { + x_update_batch[i].reset(); + } else { + ++x_update_batch[i]; + } } } @@ -384,8 +415,8 @@ namespace quda // compute the true residual mat(r, x); - double true_res = xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); + auto true_res = xmyNorm(b, r); + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(true_res[i] / b2[i]); getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } diff --git a/lib/inv_sd_quda.cpp b/lib/inv_sd_quda.cpp index a715f89fc0..ba74e5059c 100644 --- a/lib/inv_sd_quda.cpp +++ b/lib/inv_sd_quda.cpp @@ -11,55 +11,61 @@ namespace quda { SD::SD(const DiracMatrix &mat, SolverParam ¶m) : Solver(mat, mat, mat, mat, param) { } - void SD::create(ColorSpinorField &x, const ColorSpinorField &b) + void SD::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_NULL_FIELD_CREATE; - r = ColorSpinorField(csParam); - Ar = ColorSpinorField(csParam); + if (!init || r.size() != b.size()) { + resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); + resize(Ar, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); init = true; } } - ColorSpinorField &SD::get_residual() + cvector_ref SD::get_residual() { if (!init) errorQuda("No residual vector present"); return r; } - void SD::operator()(ColorSpinorField &x, ColorSpinorField &b) + void SD::operator()(cvector_ref &x, cvector_ref &b) { commGlobalReductionPush(param.global_reduction); create(x, b); - double b2 = blas::norm2(b); - double r2; + vector b2 = blas::norm2(b); + vector r2; + + // Check to see that we're not trying to invert on a zero-field source + if (is_zero_src(x, b, b2)) return; + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute the true residual mat(r, x); r2 = blas::xmyNorm(b, r); + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; } else { blas::zero(x); blas::copy(r, b); r2 = b2; } - // if invalid residual then convergence is set by iteration count only - double stop = param.residual_type == QUDA_INVALID_RESIDUAL ? 
0.0 : b2 * param.tol * param.tol; + auto stop = stopping(param.tol, b2, param.residual_type); int res_increase = 0; int k = 0; while (k < param.maxiter) { mat(Ar, r); - double3 rAr = blas::cDotProductNormA(r, Ar); - auto alpha = rAr.z / rAr.x; - r2 = rAr.z; // this is r2 from the prior iteration + auto rAr = blas::cDotProductNormA(r, Ar); + vector alpha(b.size()); + for (auto i = 0u; i < b.size(); i++) { + alpha[i] = rAr[i].z / rAr[i].x; + r2[i] = rAr[i].z; // this is r2 from the prior iteration + } - PrintStats("SD", k, r2, b2, 0.0); + PrintStats("SD", k, r2, b2); if (r2 < stop) { mat(r, x); @@ -80,11 +86,11 @@ namespace quda { if (param.compute_true_res) { // Compute the true residual mat(r, x); - double true_r2 = blas::xmyNorm(b, r); - PrintSummary("SD", k, true_r2, b2, 0.0, 0.0); - param.true_res = sqrt(true_r2 / b2); + auto true_r2 = blas::xmyNorm(b, r); + PrintSummary("SD", k, true_r2, b2, stop); + for (auto i = 0u; i < b2.size(); i++) param.true_res[i] = sqrt(true_r2[i] / b2[i]); } else { - PrintSummary("SD", k, r2, b2, 0.0, 0.0); + PrintSummary("SD", k, r2, b2, stop); } commGlobalReductionPop(); diff --git a/lib/madwf_ml.cpp b/lib/madwf_ml.cpp index adb7ee323f..1df1ebd706 100644 --- a/lib/madwf_ml.cpp +++ b/lib/madwf_ml.cpp @@ -92,9 +92,9 @@ namespace quda if (getVerbosity() >= QUDA_VERBOSE) { printfQuda("Generating Null Space Vectors ... \n"); } spinorNoise(null_b, rng, QUDA_NOISE_GAUSS); - std::vector B(16); csParam.setPrecision(prec_precondition); - for (auto &pB : B) { pB = ColorSpinorField(csParam); } + std::vector B; + resize(B, 16, csParam); getProfile().TPSTOP(QUDA_PROFILE_INIT); null.solve_and_collect(null_x, null_b, B, param.madwf_null_miniter, param.madwf_null_tol); diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp index 131c57fde0..23473eaa4b 100644 --- a/lib/milc_interface.cpp +++ b/lib/milc_interface.cpp @@ -1209,8 +1209,8 @@ void qudaInvert(int external_precision, int quda_precision, double mass, QudaInv // return the number of iterations taken by the inverter *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (!create_quda_gauge) invalidateGaugeQuda(); @@ -1424,10 +1424,19 @@ void qudaInvertMsrc(int external_precision, int quda_precision, double mass, Qud host_free(sln_pointer); host_free(src_pointer); - // return the number of iterations taken by the inverter + // The conventions for num_iters, final_residual, and final_fermilab_residual are taken from the + // convention in `generic_ks/d_congrad5_fn_milc.c` (commit 414fb31). Here, a block solve + // is emulated as a series of sequential solves. Each individual solve overrides the + // final tolerance and iteration counts from the previous solve. Therefore, num_iters + // as well as the tolerances come from the last solve. + + // invertParam.iter is the total number of iterations for the block solver, which is ~= + // to the number of iterations the last rhs would take. *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + + // MILC only cares about a single residual, which happens to be the last one as described above. 
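+ // A minimal sketch (not part of this patch; worst_res and worst_hq are
+ // hypothetical locals): a caller that prefers the most conservative
+ // per-source residual over the MILC last-source convention could instead
+ // reduce across the new per-source arrays, e.g.
+ //   double worst_res = 0.0, worst_hq = 0.0;
+ //   for (int i = 0; i < num_src; i++) {
+ //     worst_res = std::max(worst_res, invertParam.true_res[i]);
+ //     worst_hq = std::max(worst_hq, invertParam.true_res_hq[i]);
+ //   }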
+ *final_residual = invertParam.true_res[num_src - 1]; + *final_fermilab_residual = invertParam.true_res_hq[num_src - 1]; if (!create_quda_gauge) invalidateGaugeQuda(); @@ -1521,8 +1530,8 @@ void qudaEigCGInvert(int external_precision, int quda_precision, double mass, Qu // return the number of iterations taken by the inverter *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (!create_quda_gauge && last_rhs_flag) invalidateGaugeQuda(); @@ -2139,6 +2148,7 @@ void milcSetMultigridEigParam(QudaEigParam &mg_eig_param, mgInputStruct &input_s mg_eig_param.n_kr = input_struct.deflate_n_kr; // mg_eig_n_kr[level]; mg_eig_param.n_conv = input_struct.nvec[level]; mg_eig_param.n_ev_deflate = -1; // deflate everything that converged + mg_eig_param.compute_evals_batch_size = (input_struct.nvec[level] % 16 == 0) ? 16 : 1; // compute the eigenvalues in appropriate batches mg_eig_param.batched_rotate = 0; // mg_eig_batched_rotate[level]; mg_eig_param.require_convergence = QUDA_BOOLEAN_TRUE; // mg_eig_require_convergence[level] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; @@ -2319,6 +2329,7 @@ void milcSetMultigridParam(milcMultigridPack *mg_pack, QudaPrecision host_precis mg_param.setup_maxiter_refresh[i] = 0; // setup_maxiter_refresh[i]; mg_param.n_vec[i] = (i == 0) ? ((input_struct.optimized_kd == QUDA_TRANSFER_COARSE_KD) ? 24 : 3) : input_struct.nvec[i]; + mg_param.n_vec_batch[i] = (i == 0) ? 1 : (mg_param.n_vec[i] % 16 == 0 ? 16 : 1); mg_param.n_block_ortho[i] = 2; // n_block_ortho[i]; // number of times to Gram-Schmidt mg_param.precision_null[i] = input_struct.preconditioner_precision; // precision to store the null-space basis mg_param.smoother_halo_precision[i] @@ -2707,8 +2718,8 @@ void qudaInvertMG(int external_precision, int quda_precision, double mass, QudaI // return the number of iterations taken by the inverter *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (!create_quda_gauge) invalidateGaugeQuda(); @@ -3004,8 +3015,8 @@ void qudaCloverInvert(int external_precision, invertQuda(solution, source, &invertParam); *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (clover || cloverInverse) qudaFreeCloverField(); if (link) qudaFreeGaugeField(); @@ -3083,8 +3094,8 @@ void qudaEigCGCloverInvert(int external_precision, int quda_precision, double ka if (last_rhs_flag) destroyDeflationQuda(df_preconditioner); *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if ( (clover || cloverInverse) && last_rhs_flag) qudaFreeCloverField(); if (link && last_rhs_flag) qudaFreeGaugeField(); @@ -3151,7 +3162,7 @@ void qudaCloverMultishiftInvert(int external_precision, int quda_precision, int } invertQuda(solutionArray[0], source, &invertParam); - *final_residual = invertParam.true_res; + *final_residual = invertParam.true_res[0]; } else { invertMultiShiftQuda(solutionArray, source, &invertParam); for (int 
i=0; i state_history; + static std::vector state_history; /** @brief Return the time period for the monitor measurements. @@ -77,6 +77,10 @@ namespace quda while (is_running.load()) { auto state = device::get_state(); state_history.push_back(state); + + // periodically reserve larger state size to avoid push_back cost + if (state_history.size() % 100000 == 0) state_history.reserve(state_history.size() + 100000); + std::this_thread::sleep_for(get_period()); } } @@ -98,6 +102,8 @@ namespace quda if (is_enabled()) { warningQuda("Enabling device monitoring"); + // pre-reserve state_history size to avoid push_back cost + state_history.reserve(10000); start_time = std::chrono::high_resolution_clock::now(); try { // spawn monitoring thread and release @@ -187,5 +193,42 @@ namespace quda monitor_file.close(); } + size_t size() { return is_enabled() ? state_history.size() : 0; } + + state_t mean(size_t start, size_t end) + { + state_t mean; + double last_power = 0.0; + std::chrono::time_point last_time; + + if (start > 0 && end > start) { + auto start_time = state_history[start - 1].time; + auto end_time = state_history[end - 1].time; + for (auto i = start; i < end; i++) { + auto &state = state_history[i]; + if (i - start > 0) { + std::chrono::duration diff = state.time - last_time; + + // potential for non-uniform samples distribution, so integrate rather than sum + mean.power += state.power * diff.count(); + mean.temp += state.temp * diff.count(); + mean.clock += state.clock * diff.count(); + + // trapezoidal integration to compute energy + mean.energy += 0.5 * (state.power + last_power) * diff.count(); + } + last_power = state.power; + last_time = state.time; + } + + std::chrono::duration duration = end_time - start_time; + mean.power /= duration.count(); + mean.temp /= duration.count(); + mean.clock /= duration.count(); + } + + return mean; + } + } // namespace monitor } // namespace quda diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index a327717e84..5241fdeabc 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -13,8 +13,6 @@ namespace quda using namespace blas; - static bool debug = false; - MG::MG(MGParam ¶m) : Solver(*param.matResidual, *param.matSmooth, *param.matSmoothSloppy, *param.matSmoothSloppy, param), param(param), @@ -46,7 +44,7 @@ namespace quda } if (param.B[0].Nspin() == 1) csParam.gammaBasis = param.B[0].GammaBasis(); // hack for staggered to avoid unnecessary basis checks - r = ColorSpinorField(csParam); + resize(r, 1, csParam); } rng = new RNG(param.B[0], 1234); @@ -58,8 +56,8 @@ namespace quda // Initializing to random vectors for (int i = 0; i < (int)param.B.size(); i++) { - spinorNoise(r, *rng, QUDA_NOISE_UNIFORM); - param.B[i] = r; + spinorNoise(r[0], *rng, QUDA_NOISE_UNIFORM); + param.B[i] = r[0]; } } if (param.mg_global.num_setup_iter[param.level] > 0) { @@ -132,14 +130,18 @@ namespace quda param.mg_global.geo_block_size[param.level][i] = param.geoBlockSize[i]; // create coarse residual vector if not already created in verify() - if (r_coarse.empty()) - r_coarse = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r.Precision(), - param.mg_global.location[param.level + 1]); + if (r_coarse.empty()) { + r_coarse.resize(1); + r_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), + param.mg_global.location[param.level + 1]); + } // create coarse solution vector if not already created in verify() - if (x_coarse.empty()) - x_coarse = param.B[0].create_coarse(param.geoBlockSize, 
param.spinBlockSize, param.Nvec, r.Precision(), - param.mg_global.location[param.level + 1]); + if (x_coarse.empty()) { + x_coarse.resize(1); + x_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), + param.mg_global.location[param.level + 1]); + } int nVec_coarse = std::max(param.Nvec, param.mg_global.n_vec[param.level + 1]); B_coarse.resize(nVec_coarse); @@ -406,7 +408,8 @@ namespace quda diracParam.type = QUDA_COARSE_DIRAC; diracParam.halo_precision = param.mg_global.precision_null[param.level]; diracParam.setup_use_mma = param.mg_global.setup_use_mma[param.level]; - diracParam.dslash_use_mma = param.mg_global.dslash_use_mma[param.level]; + // level + 1 since this is for the coarse grid + diracParam.dslash_use_mma = param.mg_global.dslash_use_mma[param.level + 1]; diracParam.allow_truncation = (param.mg_global.allow_truncation == QUDA_BOOLEAN_TRUE) ? true : false; diracCoarseResidual = new DiracCoarse(diracParam, param.setup_location == QUDA_CUDA_FIELD_LOCATION ? true : false, @@ -660,7 +663,7 @@ namespace quda param_coarse_solver->verbosity_precondition = param.mg_global.verbosity[param.level+1]; // preconditioned solver wrapper is uniform precision - param_coarse_solver->precision = r_coarse.Precision(); + param_coarse_solver->precision = r_coarse[0].Precision(); param_coarse_solver->precision_sloppy = param_coarse_solver->precision; param_coarse_solver->precision_precondition = param_coarse_solver->precision_sloppy; @@ -699,9 +702,9 @@ namespace quda // Run a dummy solve so that the deflation space is constructed and computed if needed during the MG setup, // or the eigenvalues are recomputed during transfer. - spinorNoise(r_coarse, *coarse->rng, QUDA_NOISE_UNIFORM); + spinorNoise(r_coarse[0], *coarse->rng, QUDA_NOISE_UNIFORM); param_coarse_solver->maxiter = 1; // do a single iteration on the dummy solve - (*coarse_solver)(x_coarse, r_coarse); + (*coarse_solver)(x_coarse[0], r_coarse[0]); setOutputPrefix(prefix); // restore since we just popped back from coarse grid param_coarse_solver->maxiter = param.mg_global.coarse_solver_maxiter[param.level + 1]; } @@ -760,8 +763,9 @@ namespace quda { pushLevel(param.level); - QudaPrecision prec = (param.mg_global.precision_null[param.level] < r.Precision()) ? - param.mg_global.precision_null[param.level] : r.Precision(); + QudaPrecision prec = (param.mg_global.precision_null[param.level] < r[0].Precision()) ? 
+ param.mg_global.precision_null[param.level] : + r[0].Precision(); // may want to revisit this---these were relaxed for cases where ghost_precision < precision // these were set while hacking in tests of quarter precision ghosts @@ -776,12 +780,12 @@ namespace quda // temporary fields used for verification std::vector fine_tmp(param.Nvec); - ColorSpinorParam fine_param(r); + ColorSpinorParam fine_param(r[0]); fine_param.create = QUDA_NULL_FIELD_CREATE; for (auto &f : fine_tmp) f = ColorSpinorField(fine_param); std::vector coarse_tmp(param.Nvec); - ColorSpinorParam coarse_param(r_coarse); + ColorSpinorParam coarse_param(r_coarse[0]); coarse_param.create = QUDA_NULL_FIELD_CREATE; for (auto &c : coarse_tmp) c = ColorSpinorField(coarse_param); @@ -789,6 +793,9 @@ namespace quda auto &tmp2 = fine_tmp[1]; auto &tmp_coarse = coarse_tmp[0]; + vector B_norm; + if (param.transfer_type == QUDA_TRANSFER_AGGREGATE) B_norm = norm2(param.B); + // No need to check (projector) v_k for staggered case if (param.transfer_type == QUDA_TRANSFER_AGGREGATE) { @@ -799,19 +806,21 @@ namespace quda transfer->R(coarse_tmp, param.B); transfer->P(fine_tmp, coarse_tmp); + auto max_deviation = blas::max_deviation(param.B, fine_tmp); + auto deviation = xmyNorm(param.B, fine_tmp); + auto coarse_norm = norm2(coarse_tmp); + auto fine_norm = norm2(coarse_tmp); for (auto i = 0; i < param.Nvec; i++) { - auto max_deviation = blas::max_deviation(param.B[i], fine_tmp[i]); - auto l2_deviation = sqrt(xmyNorm(param.B[i], fine_tmp[i]) / norm2(param.B[i])); - - logQuda(QUDA_VERBOSE, - "Vector %d: L2 norms v_k = %e P^\\dagger v_k = %e (1 - P P^\\dagger) v_k = %e; Deviations: L2 relative = %e, max = %e\n", - i, norm2(param.B[i]), norm2(coarse_tmp[i]), norm2(fine_tmp[i]), l2_deviation, max_deviation[0]); + auto l2_deviation = sqrt(deviation[i]) / B_norm[i]; + logQuda( + QUDA_VERBOSE, "Vector %d: L2 norms v_k = %e P^\\dagger v_k = %e (1 - P P^\\dagger) v_k = %e; Deviations: L2 relative = %e, max = %e\n", + i, B_norm[i], coarse_norm[i], fine_norm[i], l2_deviation, max_deviation[i][0]); if (check_deviation(l2_deviation, tol)) errorQuda("k=%d orthonormality failed: L2 relative deviation %e > %e", i, l2_deviation, tol); - if (check_deviation(max_deviation[0], tol)) - errorQuda("k=%d orthonormality failed: max deviation %e > %e", i, max_deviation[0], tol); + if (check_deviation(max_deviation[i][0], tol)) + errorQuda("k=%d orthonormality failed: max deviation %e > %e", i, max_deviation[i][0], tol); } - for (auto &f : fine_tmp) f.GammaBasis(r.GammaBasis()); // restore basis + for (auto &f : fine_tmp) f.GammaBasis(r[0].GammaBasis()); // restore basis // the oblique check if (param.mg_global.run_oblique_proj_check) { @@ -823,14 +832,14 @@ namespace quda logQuda(QUDA_SUMMARIZE, "Checking 1 > || (1 - DP(P^dagDP)P^dag) v_k || / || v_k || for %d vectors\n", param.Nvec); for (int i = 0; i < param.Nvec; i++) { - transfer->R(r_coarse, param.B[i]); - (*coarse_solver)(x_coarse, r_coarse); // this needs to be an exact solve to pass + transfer->R(r_coarse[0], param.B[i]); + (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore prefix after return from coarse grid - transfer->P(tmp2, x_coarse); + transfer->P(tmp2, x_coarse[0]); (*param.matResidual)(tmp1, tmp2); tmp2 = param.B[i]; - logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e\n", i, norm2(param.B[i]), norm2(tmp1)); - logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / norm2(param.B[i]))); + 
logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e\n", i, B_norm[i], norm2(tmp1)); + logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / B_norm[i])); } sprintf(prefix, "MG level %d (%s): ", param.level + 1, param.location == QUDA_CUDA_FIELD_LOCATION ? "GPU" : "CPU"); @@ -844,37 +853,41 @@ namespace quda for (int i=0; iR(r_coarse, param.B[i]); - (*coarse)(x_coarse, r_coarse); // this needs to be an exact solve to pass + (*coarse)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore output prefix - transfer->P(tmp2, x_coarse); + transfer->P(tmp2, x_coarse[0]); param.matResidual(tmp1, tmp2); tmp2 = param.B[i]; - logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e ", i, norm2(param.B[i]), norm2(tmp1)); - logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / norm2(param.B[i])) ); + logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e ", i, B_norm[i], norm2(tmp1)); + logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / B_norm[i]) ); } #endif // create coarse residual vector if not already created in verify() - if (r_coarse.empty()) - r_coarse = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r.Precision(), - param.mg_global.location[param.level + 1]); + if (r_coarse.empty()) { + r_coarse.resize(1); + r_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), + param.mg_global.location[param.level + 1]); + } // create coarse solution vector if not already created in verify() - if (x_coarse.empty()) - x_coarse = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r.Precision(), - param.mg_global.location[param.level + 1]); + if (x_coarse.empty()) { + x_coarse.resize(1); + x_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), + param.mg_global.location[param.level + 1]); + } { logQuda(QUDA_SUMMARIZE, "Checking 0 = (1 - P^\\dagger P) eta_c\n"); - spinorNoise(x_coarse, *rng, QUDA_NOISE_UNIFORM); - transfer->P(tmp2, x_coarse); - transfer->R(r_coarse, tmp2); - auto r2 = norm2(r_coarse); - auto max_deviation = blas::max_deviation(r_coarse, x_coarse); - auto l2_deviation = sqrt(xmyNorm(x_coarse, r_coarse) / norm2(x_coarse)); - logQuda(QUDA_VERBOSE, "L2 norms %e %e (fine tmp %e); Deviations: L2 relative = %e, max = %e\n", norm2(x_coarse), - r2, norm2(tmp2), l2_deviation, max_deviation[0]); + spinorNoise(x_coarse[0], *rng, QUDA_NOISE_UNIFORM); + transfer->P(tmp2, x_coarse[0]); + transfer->R(r_coarse[0], tmp2); + auto r2 = norm2(r_coarse[0]); + auto max_deviation = blas::max_deviation(r_coarse[0], x_coarse[0]); + auto l2_deviation = sqrt(xmyNorm(x_coarse[0], r_coarse[0]) / norm2(x_coarse[0])); + logQuda(QUDA_VERBOSE, "L2 norms %e %e (fine tmp %e); Deviations: L2 relative = %e, max = %e\n", + norm2(x_coarse[0]), r2, norm2(tmp2), l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("coarse span failed: L2 relative deviation = %e > %e", l2_deviation, tol); if (check_deviation(max_deviation[0], tol)) @@ -952,7 +965,7 @@ namespace quda (*param.matResidual)(tmp2, tmp1); } - transfer->R(x_coarse, tmp2); + transfer->R(x_coarse[0], tmp2); static_cast(diracCoarseResidual)->M(r_coarse, tmp_coarse); #if 0 // enable to print out emulated and actual coarse-grid operator vectors for debugging @@ -963,20 +976,20 @@ namespace quda printfQuda("\nemulated\n"); comm_barrier(); for (int parity = 0; parity < 2; parity++) - 
for (unsigned int x_cb = 0; x_cb < x_coarse.VolumeCB(); x_cb++) x_coarse.PrintVector(parity, x_cb, rank); + for (unsigned int x_cb = 0; x_cb < x_coarse[0].VolumeCB(); x_cb++) x_coarse[0].PrintVector(parity, x_cb, rank); comm_barrier(); printfQuda("\nactual\n"); comm_barrier(); for (int parity = 0; parity < 2; parity++) - for (unsigned int x_cb = 0; x_cb < r_coarse.VolumeCB(); x_cb++) r_coarse.PrintVector(parity, x_cb, rank); + for (unsigned int x_cb = 0; x_cb < r_coarse[0].VolumeCB(); x_cb++) r_coarse[0].PrintVector(parity, x_cb, rank); } setOutputPrefix(prefix); #endif - double r_nrm = norm2(r_coarse); - auto max_deviation = blas::max_deviation(r_coarse, x_coarse); - auto l2_deviation = sqrt(xmyNorm(x_coarse, r_coarse) / norm2(x_coarse)); + double r_nrm = norm2(r_coarse[0]); + auto max_deviation = blas::max_deviation(r_coarse[0], x_coarse[0]); + auto l2_deviation = sqrt(xmyNorm(x_coarse[0], r_coarse[0]) / norm2(x_coarse[0])); if (diracResidual->Mu() != 0.0) { // When the mu is shifted on the coarse level; we can compute exactly the error we introduce in the check: @@ -985,13 +998,13 @@ namespace quda if (fabs(delta_factor) > tol) { double delta_a = delta_factor * 2.0 * diracResidual->Kappa() * diracResidual->Mu() * transfer->Vectors().TwistFlavor(); - l2_deviation -= fabs(delta_a) * sqrt(norm2(tmp_coarse) / norm2(x_coarse)); + l2_deviation -= fabs(delta_a) * sqrt(norm2(tmp_coarse) / norm2(x_coarse[0])); l2_deviation = fabs(l2_deviation); max_deviation[0] -= fabs(delta_a); } } logQuda(QUDA_VERBOSE, "L2 norms: Emulated = %e, Native = %e; Deviations: L2 relative = %e, max = %e\n", - norm2(x_coarse), r_nrm, l2_deviation, max_deviation[0]); + norm2(x_coarse[0]), r_nrm, l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("Coarse operator failed: L2 relative deviation = %e > %e", l2_deviation, tol); @@ -1005,14 +1018,14 @@ namespace quda if (coarse_was_preconditioned) { // check eo logQuda(QUDA_SUMMARIZE, "Checking Deo of preconditioned operator 0 = \\hat{D}_c - A^{-1} D_c\n"); - static_cast(diracCoarseResidual)->Dslash(r_coarse.Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); - static_cast(diracCoarseResidual)->CloverInv(x_coarse.Even(), r_coarse.Even(), QUDA_EVEN_PARITY); - static_cast(diracCoarseSmoother)->Dslash(r_coarse.Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); - double r_nrm = norm2(r_coarse.Even()); - auto max_deviation = blas::max_deviation(r_coarse.Even(), x_coarse.Even()); - auto l2_deviation = sqrt(xmyNorm(x_coarse.Even(), r_coarse.Even()) / norm2(x_coarse.Even())); + static_cast(diracCoarseResidual)->Dslash(r_coarse[0].Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); + static_cast(diracCoarseResidual)->CloverInv(x_coarse[0].Even(), r_coarse[0].Even(), QUDA_EVEN_PARITY); + static_cast(diracCoarseSmoother)->Dslash(r_coarse[0].Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); + double r_nrm = norm2(r_coarse[0].Even()); + auto max_deviation = blas::max_deviation(r_coarse[0].Even(), x_coarse[0].Even()); + auto l2_deviation = sqrt(xmyNorm(x_coarse[0].Even(), r_coarse[0].Even()) / norm2(x_coarse[0].Even())); logQuda(QUDA_VERBOSE, "L2 norms: Emulated = %e, Native = %e; Deviations: L2 relative = %e, max = %e\n", - norm2(x_coarse.Even()), r_nrm, l2_deviation, max_deviation[0]); + norm2(x_coarse[0].Even()), r_nrm, l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("Preconditioned Deo failed: L2 relative deviation = %e > %e", l2_deviation, tol); if (check_deviation(max_deviation[0], tol)) @@ -1020,14 +1033,14 @@ namespace quda // check Doe 
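+ // (mirror of the Deo check above with the parities swapped: apply the
+ // unpreconditioned coarse hop D_c from the even sites, undo the clover
+ // factor with A^{-1}, and compare against the preconditioned \hat{D}_c
+ // acting odd <- even)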
logQuda(QUDA_SUMMARIZE, "Checking Doe of preconditioned operator 0 = \\hat{D}_c - A^{-1} D_c\n"); - static_cast(diracCoarseResidual)->Dslash(r_coarse.Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); - static_cast(diracCoarseResidual)->CloverInv(x_coarse.Odd(), r_coarse.Odd(), QUDA_ODD_PARITY); - static_cast(diracCoarseSmoother)->Dslash(r_coarse.Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); - r_nrm = norm2(r_coarse.Odd()); - max_deviation = blas::max_deviation(r_coarse.Odd(), x_coarse.Odd()); - l2_deviation = sqrt(xmyNorm(x_coarse.Odd(), r_coarse.Odd()) / norm2(x_coarse.Odd())); + static_cast(diracCoarseResidual)->Dslash(r_coarse[0].Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); + static_cast(diracCoarseResidual)->CloverInv(x_coarse[0].Odd(), r_coarse[0].Odd(), QUDA_ODD_PARITY); + static_cast(diracCoarseSmoother)->Dslash(r_coarse[0].Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); + r_nrm = norm2(r_coarse[0].Odd()); + max_deviation = blas::max_deviation(r_coarse[0].Odd(), x_coarse[0].Odd()); + l2_deviation = sqrt(xmyNorm(x_coarse[0].Odd(), r_coarse[0].Odd()) / norm2(x_coarse[0].Odd())); logQuda(QUDA_VERBOSE, "L2 norms: Emulated = %e, Native = %e; Deviations: L2 relative = %e, max = %e\n", - norm2(x_coarse.Odd()), r_nrm, l2_deviation, max_deviation[0]); + norm2(x_coarse[0].Odd()), r_nrm, l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("Preconditioned Doe failed: L2 relative deviation = %e > %e", l2_deviation, tol); if (check_deviation(max_deviation[0], tol)) @@ -1084,16 +1097,16 @@ namespace quda for (int i = 0; i < param.Nvec; i++) { // Restrict Evec, place result in r_coarse - transfer->R(r_coarse, param.B[i]); + transfer->R(r_coarse[0], param.B[i]); // Prolong r_coarse, place result in tmp2 - transfer->P(tmp2, r_coarse); + transfer->P(tmp2, r_coarse[0]); - printfQuda("Vector %d: norms v_k = %e P^dag v_k = %e PP^dag v_k = %e\n", i, norm2(param.B[i]), - norm2(r_coarse), norm2(tmp2)); + printfQuda("Vector %d: norms v_k = %e P^dag v_k = %e PP^dag v_k = %e\n", i, B_norm[i], norm2(r_coarse[0]), + norm2(tmp2)); // Compare v_k and PP^dag v_k. 
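+ // For an orthonormal near-null basis the transfer pair satisfies
+ // P P^\dagger v_k = v_k, so both the relative L2 deviation
+ // sqrt(||(1 - P P^\dagger) v_k||^2 / ||v_k||^2) and the max site-wise
+ // deviation should vanish to the tolerance of the working precision.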
auto max_deviation = blas::max_deviation(tmp2, param.B[i]); - auto l2_deviation = sqrt(xmyNorm(param.B[i], tmp2) / norm2(param.B[i])); + auto l2_deviation = sqrt(xmyNorm(param.B[i], tmp2) / B_norm[i]); printfQuda("L2 relative deviation = %e max deviation = %e\n", l2_deviation, max_deviation[0]); if (param.mg_global.run_oblique_proj_check) { @@ -1105,17 +1118,16 @@ namespace quda // Oblique projections logQuda(QUDA_SUMMARIZE, "Checking 1 > || (1 - DP(P^dagDP)P^dag) v_k || / || v_k || for vector %d\n", i); - transfer->R(r_coarse, param.B[i]); - (*coarse_solver)(x_coarse, r_coarse); // this needs to be an exact solve to pass + transfer->R(r_coarse[0], param.B[i]); + (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore prefix after return from coarse grid - transfer->P(tmp2, x_coarse); + transfer->P(tmp2, x_coarse[0]); (*param.matResidual)(tmp1, tmp2); - logQuda(QUDA_SUMMARIZE, "Vector %d: norms v_k %e DP(P^dagDP)P^dag v_k %e\n", i, norm2(param.B[i]), - norm2(tmp1)); + logQuda(QUDA_SUMMARIZE, "Vector %d: norms v_k %e DP(P^dagDP)P^dag v_k %e\n", i, B_norm[i], norm2(tmp1)); max_deviation = blas::max_deviation(tmp1, param.B[i]); logQuda(QUDA_SUMMARIZE, "L2 relative deviation = %e, max deviation = %e\n", - sqrt(xmyNorm(param.B[i], tmp1) / norm2(param.B[i])), max_deviation[0]); + sqrt(xmyNorm(param.B[i], tmp1) / B_norm[i]), max_deviation[0]); } sprintf(prefix, "MG level %d (%s): ", param.level + 1, @@ -1130,8 +1142,12 @@ namespace quda popLevel(); } - void MG::operator()(ColorSpinorField &x, ColorSpinorField &b) + void MG::operator()(cvector_ref &x, cvector_ref &b) { + resize(r, b.size(), QUDA_NULL_FIELD_CREATE); + resize(r_coarse, b.size(), QUDA_NULL_FIELD_CREATE); + resize(x_coarse, b.size(), QUDA_NULL_FIELD_CREATE); + pushOutputPrefix(prefix); if (param.level < param.Nlevel - 1) { // set parity for the solver in the transfer operator @@ -1150,38 +1166,23 @@ namespace quda QudaSolutionType outer_solution_type = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? 
QUDA_MAT_SOLUTION : QUDA_MATPC_SOLUTION; QudaSolutionType inner_solution_type = param.coarse_grid_solution_type; - if (debug) printfQuda("outer_solution_type = %d, inner_solution_type = %d\n", outer_solution_type, inner_solution_type); - if ( outer_solution_type == QUDA_MATPC_SOLUTION && inner_solution_type == QUDA_MAT_SOLUTION) errorQuda("Unsupported solution type combination"); if ( inner_solution_type == QUDA_MATPC_SOLUTION && param.smoother_solve_type != QUDA_DIRECT_PC_SOLVE) errorQuda("For this coarse grid solution type, a preconditioned smoother is required"); - if ( debug ) printfQuda("entering V-cycle with x2=%e, r2=%e\n", norm2(x), norm2(b)); - if (param.level < param.Nlevel-1) { - //transfer->setTransferGPU(false); // use this to force location of transfer (need to check if still works for multi-level) - // do the pre smoothing - if (debug) printfQuda("pre-smoothing b2=%e site subset %d\n", norm2(b), b.SiteSubset()); - - ColorSpinorField out, in; + std::vector out(b.size()), in(b.size()); diracSmoother->prepare(out, in, x, b, outer_solution_type); - ColorSpinorField b_tilde; - // if we're using preconditioning then allocate storage for the preconditioned source vector - if (param.smoother_solve_type == QUDA_DIRECT_PC_SOLVE) { - b_tilde = getFieldTmp(r.Even()); - b_tilde = in; // b_tilde holds either a copy of preconditioned source or a pointer to original source - } - if (presmoother) (*presmoother)(out, in); else zero(out); - ColorSpinorField &solution = inner_solution_type == outer_solution_type ? x : x.Even(); + auto &solution = inner_solution_type == outer_solution_type ? x : x.Even(); diracSmoother->reconstruct(solution, b, inner_solution_type); // if using preconditioned smoother then need to reconstruct full residual @@ -1194,16 +1195,17 @@ namespace quda false; // FIXME this is currently borked if inner solver is preconditioned - ColorSpinorField &residual = !presmoother ? b : + const auto &residual = !presmoother ? b : use_solver_residual ? presmoother->get_residual() : - b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? r : - r.Even(); + b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : + cvector_ref(r).Even(); if (!use_solver_residual && presmoother) { + auto &residual = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : + cvector_ref(r).Even(); (*param.matResidual)(residual, x); axpby(1.0, b, -1.0, residual); } - double r2 = debug ? norm2(residual) : 0.0; // We need this to ensure that the coarse level has been created. // e.g. in case of iterative setup with MG we use just pre- and post-smoothing at the first iteration. @@ -1211,32 +1213,16 @@ namespace quda // restrict to the coarse grid transfer->R(r_coarse, residual); - if (debug) printfQuda("after pre-smoothing x2 = %e, r2 = %e, r_coarse2 = %e\n", norm2(x), r2, norm2(r_coarse)); // recurse to the next lower level (*coarse_solver)(x_coarse, r_coarse); - if (debug) printfQuda("after coarse solve x_coarse2 = %e r_coarse2 = %e\n", norm2(x_coarse), norm2(r_coarse)); // prolongate back to this grid - ColorSpinorField &x_coarse_2_fine - = inner_solution_type == QUDA_MAT_SOLUTION ? r : r.Even(); // define according to inner solution type + auto &x_coarse_2_fine = inner_solution_type == QUDA_MAT_SOLUTION ? 
+ cvector_ref(r) : + cvector_ref(r).Even(); // define according to inner solution type transfer->P(x_coarse_2_fine, x_coarse); // repurpose residual storage xpy(x_coarse_2_fine, solution); // sum to solution FIXME - sum should be done inside the transfer operator - if ( debug ) { - printfQuda("Prolongated coarse solution y2 = %e\n", norm2(r)); - printfQuda("after coarse-grid correction x2 = %e, r2 = %e\n", norm2(x), norm2(r)); - } - } - - if (debug) printfQuda("preparing to post smooth\n"); - - // do the post smoothing - // residual = outer_solution_type == QUDA_MAT_SOLUTION ? r : r.Even(); // refine for outer solution type - if (param.smoother_solve_type == QUDA_DIRECT_PC_SOLVE) { - in = b_tilde.create_alias(); - } else { // this incurs unecessary copying - r = b; - in = r.create_alias(); } // we should keep a copy of the prepared right hand side as we've already destroyed it @@ -1245,25 +1231,15 @@ namespace quda if (postsmoother) (*postsmoother)(out, in); // for inner solve preconditioned, in the should be the original prepared rhs - if (debug) printfQuda("exited postsmooth, about to reconstruct\n"); - diracSmoother->reconstruct(x, b, outer_solution_type); - if (debug) printfQuda("finished reconstruct\n"); - } else { // do the coarse grid solve - ColorSpinorField out, in; + std::vector out(b.size()), in(b.size()); diracSmoother->prepare(out, in, x, b, outer_solution_type); if (presmoother) (*presmoother)(out, in); diracSmoother->reconstruct(x, b, outer_solution_type); - } - // FIXME on subset check - if (debug && b.SiteSubset() == r.SiteSubset()) { - (*param.matResidual)(r, x); - double r2 = xmyNorm(b, r); - printfQuda("leaving V-cycle with x2=%e, r2=%e\n", norm2(x), r2); } popOutputPrefix(); @@ -1343,7 +1319,7 @@ namespace quda } solverParam.pipeline = (solverParam.inv_type == QUDA_BICGSTAB_INVERTER ? 0 : 4); // FIXME: pipeline != 0 breaks BICGSTAB - solverParam.precision = r.Precision(); + solverParam.precision = r[0].Precision(); if (is_fine_grid()) { solverParam.precision_sloppy = param.mg_global.invert_param->cuda_prec_precondition; @@ -1355,13 +1331,14 @@ namespace quda solverParam.residual_type = static_cast(QUDA_L2_RELATIVE_RESIDUAL); solverParam.compute_null_vector = QUDA_COMPUTE_NULL_VECTOR_YES; ColorSpinorParam csParam(B[0]); // Create spinor field parameters: - csParam.setPrecision(r.Precision(), r.Precision(), true); // ensure native ordering + csParam.setPrecision(r[0].Precision(), r[0].Precision(), true); // ensure native ordering csParam.location = QUDA_CUDA_FIELD_LOCATION; // hard code to GPU location for null-space generation for now csParam.gammaBasis = B[0].Nspin() == 1 ? 
QUDA_DEGRAND_ROSSI_GAMMA_BASIS : QUDA_UKQCD_GAMMA_BASIS; // degrand-rossi required for staggered csParam.create = QUDA_ZERO_FIELD_CREATE; - ColorSpinorField b(csParam); - ColorSpinorField x(csParam); + std::vector b, x; + resize(b, param.n_vec_batch, csParam); + resize(x, param.n_vec_batch, csParam); csParam.create = QUDA_NULL_FIELD_CREATE; @@ -1428,25 +1405,35 @@ namespace quda } // launch solver for each source - for (auto i = 0u; i < B.size(); i++) { - if (param.mg_global.setup_type == QUDA_TEST_VECTOR_SETUP) { // DDalphaAMG test vector idea - b = B[i]; // inverting against the vector - zero(x); // with zero initial guess + if (B.size() % param.n_vec_batch != 0) errorQuda("Bad batch size %d", param.n_vec_batch); + for (auto i = 0u; i < B.size(); i += param.n_vec_batch) { + if (param.mg_global.setup_type + == QUDA_TEST_VECTOR_SETUP) { // DDalphaAMG test vector idea solving against the vector + copy({b.begin(), b.begin() + param.n_vec_batch}, {B.begin() + i, B.begin() + i + param.n_vec_batch}); + zero(x); // with zero initial guess } else { - x = B[i]; + copy({x.begin(), x.begin() + param.n_vec_batch}, {B.begin() + i, B.begin() + i + param.n_vec_batch}); zero(b); } - logQuda(QUDA_VERBOSE, "Initial guess = %g\n", norm2(x)); - logQuda(QUDA_VERBOSE, "Initial rhs = %g\n", norm2(b)); + if (getVerbosity() >= QUDA_VERBOSE) { + auto nrm2 = norm2(x); + auto b2 = norm2(b); + for (auto j = 0; j < param.n_vec_batch; j++) + printfQuda("%d Initial guess = %g, Initial rhs = %g\n", i + j, nrm2[j], b2[j]); + } - ColorSpinorField out, in; + std::vector out(param.n_vec_batch), in(param.n_vec_batch); diracSmoother->prepare(out, in, x, b, QUDA_MAT_SOLUTION); (*solve)(out, in); diracSmoother->reconstruct(x, b, QUDA_MAT_SOLUTION); - logQuda(QUDA_VERBOSE, "Solution = %g\n", norm2(x)); - B[i] = x; + if (getVerbosity() >= QUDA_VERBOSE) { + auto nrm2 = norm2(x); + for (auto j = 0; j < param.n_vec_batch; j++) printfQuda("%d Solution = %g\n", i + j, nrm2[j]); + } + + copy({B.begin() + i, B.begin() + i + param.n_vec_batch}, {x.begin(), x.begin() + param.n_vec_batch}); } // global orthonormalization of the generated null-space vectors @@ -1677,7 +1664,6 @@ namespace quda } logQuda(QUDA_VERBOSE, "Done building free vectors\n"); - popLevel(); } diff --git a/lib/quda_fortran.F90 b/lib/quda_fortran.F90 index 15ee245ea3..8eda362fbe 100644 --- a/lib/quda_fortran.F90 +++ b/lib/quda_fortran.F90 @@ -124,8 +124,8 @@ module quda_fortran real(8) :: tol_restart ! Solver tolerance in the L2 residual norm (used to restart InitCG) real(8) :: tol_hq ! Requested heavy quark residual norm integer(4) :: compute_true_res ! Whether to compute the true residual post solve - real(8) :: true_res ! Actual L2 residual norm achieved in solver - real(8) :: true_res_hq ! Actual heavy quark residual norm achieved in solver + real(8), dimension(QUDA_MAX_MULTI_SRC) :: true_res ! Actual L2 residual norm achieved in solver + real(8), dimension(QUDA_MAX_MULTI_SRC) :: true_res_hq ! Actual heavy quark residual norm achieved in solver integer(4) :: maxiter real(8) :: reliable_delta ! Reliable update tolerance real(8) :: reliable_delta_refinement ! Reliable update tolerance used in post multi-shift solver refinement @@ -219,6 +219,10 @@ module quda_fortran integer(4) :: iter real(8) :: gflops real(8) :: secs + real(8) :: energy + real(8) :: power + real(8) :: temp + real(8) :: clock ! Enable auto-tuning? 
QudaTune :: tune diff --git a/lib/reduce_quda.cu b/lib/reduce_quda.cu index 61790d2ec5..58e2e9b199 100644 --- a/lib/reduce_quda.cu +++ b/lib/reduce_quda.cu @@ -135,7 +135,7 @@ namespace quda { if (r.write.V) v.restore(); } - long long flops() const override { return r.flops() * x.Length(); } + long long flops() const override { return r.flops() * x.Length() * x.size(); } long long bytes() const override { diff --git a/lib/solve.cpp b/lib/solve.cpp new file mode 100644 index 0000000000..d4eb2bed8d --- /dev/null +++ b/lib/solve.cpp @@ -0,0 +1,421 @@ +#include "invert_quda.h" + +namespace quda +{ + + // vector of spinors used for forecasting solutions in HMC +#define QUDA_MAX_CHRONO 12 + // each entry is one p + std::vector> chronoResident(QUDA_MAX_CHRONO); + + void flushChrono(int i) + { + if (i >= QUDA_MAX_CHRONO) errorQuda("Requested chrono index %d is outside of max %d", i, QUDA_MAX_CHRONO); + + if (i >= 0) + chronoResident[i].clear(); + else + for (auto i = 0; i < QUDA_MAX_CHRONO; i++) chronoResident[i].clear(); + } + + void massRescale(cvector_ref &b, QudaInvertParam ¶m, bool for_multishift) + { + double kappa5 = (0.5 / (5.0 + param.m5)); + double kappa = (param.dslash_type == QUDA_DOMAIN_WALL_DSLASH || param.dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH + || param.dslash_type == QUDA_MOBIUS_DWF_DSLASH || param.dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) ? + kappa5 : + param.kappa; + + logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: Kappa is: %g\n", kappa); + logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: mass normalization: %d\n", param.mass_normalization); + if (getVerbosity() > QUDA_DEBUG_VERBOSE) { + auto b2 = blas::norm2(b); + for (auto &b2i : b2) printfQuda("Mass rescale: norm of source in = %g\n", b2i); + } + + // staggered dslash uses mass normalization internally + if (param.dslash_type == QUDA_ASQTAD_DSLASH || param.dslash_type == QUDA_STAGGERED_DSLASH) { + switch (param.solution_type) { + case QUDA_MAT_SOLUTION: + case QUDA_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(2.0 * param.mass, b); + break; + case QUDA_MATDAG_MAT_SOLUTION: + case QUDA_MATPCDAG_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(4.0 * param.mass * param.mass, b); + break; + default: errorQuda("Not implemented"); + } + return; + } + + // multiply the source to compensate for normalization of the Dirac operator, if necessary + // you are responsible for restoring what's in param.offset + switch (param.solution_type) { + case QUDA_MAT_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION + || param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(2.0 * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; + } + break; + case QUDA_MATDAG_MAT_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION + || param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(4.0 * kappa * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } + break; + case QUDA_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { + blas::ax(4.0 * kappa * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(2.0 * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * 
kappa; + } + break; + case QUDA_MATPCDAG_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { + blas::ax(16.0 * std::pow(kappa, 4), b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 16.0 * std::pow(kappa, 4); + } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(4.0 * kappa * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } + break; + default: errorQuda("Solution type %d not supported", param.solution_type); + } + + if (getVerbosity() > QUDA_DEBUG_VERBOSE) { + auto b2 = blas::norm2(b); + for (auto &b2i : b2) printfQuda("Mass rescale: norm of source out = %g\n", b2i); + } + } + + void distanceReweight(cvector_ref &b, QudaInvertParam ¶m, bool inverse) + { + // Force the alpha0 to be positive. + // A negative alpha0 matches something like Eq.(12) in arXiv:1006.4028. + // Disable the negative situation as QUDA already has multigrid for light quarks. + const double alpha0 = abs(param.distance_pc_alpha0); + const int t0 = param.distance_pc_t0; + if (alpha0 != 0.0 && t0 >= 0) { + if (param.dslash_type != QUDA_WILSON_DSLASH && param.dslash_type != QUDA_CLOVER_WILSON_DSLASH) { + errorQuda("Only Wilson and Wilson-clover dslash support distance preconditioning, but get dslash_type %d\n", + param.dslash_type); + } + if (param.inv_type == QUDA_MG_INVERTER) errorQuda("Multigrid solver doesn't support distance preconditioning"); + + if (param.cuda_prec != QUDA_DOUBLE_PRECISION || param.cuda_prec_sloppy != QUDA_DOUBLE_PRECISION) { + warningQuda( + "Using single or half (sloppy) precision in distance preconditioning sometimes makes the solver diverge"); + } + + if (inverse) + for (auto i = 0u; i < b.size(); i++) spinorDistanceReweight(b[i], -alpha0, t0); + else + for (auto i = 0u; i < b.size(); i++) spinorDistanceReweight(b[i], alpha0, t0); + } + } + + void solve(cvector_ref &x, cvector_ref &b, Dirac &dirac, Dirac &diracSloppy, + Dirac &diracPre, Dirac &diracEig, QudaInvertParam ¶m) + { + getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); + + bool mat_solution = (param.solution_type == QUDA_MAT_SOLUTION) || (param.solution_type == QUDA_MATPC_SOLUTION); + bool direct_solve = (param.solve_type == QUDA_DIRECT_SOLVE) || (param.solve_type == QUDA_DIRECT_PC_SOLVE); + bool norm_error_solve = (param.solve_type == QUDA_NORMERR_SOLVE) || (param.solve_type == QUDA_NORMERR_PC_SOLVE); + + auto nb = blas::norm2(b); + for (auto &bi : nb) { + if (bi == 0.0) errorQuda("Source has zero norm"); + logQuda(QUDA_VERBOSE, "Source: %g\n", bi); + } + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { + auto x_norm = blas::norm2(x); + for (auto &xi : x_norm) logQuda(QUDA_VERBOSE, "Initial guess: %g\n", xi); + } + // rescale the source and solution vectors to help prevent the onset of underflow + if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION) { + auto nb_inv(nb); + for (auto bi : nb_inv) bi = 1 / sqrt(bi); + blas::ax(nb_inv, b); + blas::ax(nb_inv, x); + } + + massRescale(b, param, false); + distanceReweight(b, param, true); + + std::vector in(b.size()); + std::vector out(b.size()); + + // if we're doing a managed memory MG solve and prefetching is + // enabled, prefetch all the Dirac matrices. There's probably + // a better place to put this... 
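+ // (With managed memory the fields backing these operators may still be
+ // resident on the host after setup; prefetching them to the device up
+ // front is intended to avoid demand-paging stalls once the MG
+ // preconditioner starts applying them.)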
+ if (param.inv_type_precondition == QUDA_MG_INVERTER) { + dirac.prefetch(QUDA_CUDA_FIELD_LOCATION); + diracSloppy.prefetch(QUDA_CUDA_FIELD_LOCATION); + diracPre.prefetch(QUDA_CUDA_FIELD_LOCATION); + } + + dirac.prepare(out, in, x, b, param.solution_type); + + if (getVerbosity() >= QUDA_VERBOSE) { + auto in_norm = blas::norm2(in); + auto out_norm = blas::norm2(out); + for (auto i = 0u; i < in.size(); i++) + logQuda(QUDA_VERBOSE, "Prepared: source = %g, solution = %g\n", in_norm[i], out_norm[i]); + } + + // solution_type specifies *what* system is to be solved. + // solve_type specifies *how* the system is to be solved. + // + // We have the following four cases (plus preconditioned variants): + // + // solution_type solve_type Effect + // ------------- ---------- ------ + // MAT DIRECT Solve Ax=b + // MATDAG_MAT DIRECT Solve A^dag y = b, followed by Ax=y + // MAT NORMOP Solve (A^dag A) x = (A^dag b) + // MATDAG_MAT NORMOP Solve (A^dag A) x = b + // MAT NORMERR Solve (A A^dag) y = b, then x = A^dag y + // + // We generally require that the solution_type and solve_type + // preconditioning match. As an exception, the unpreconditioned MAT + // solution_type may be used with any solve_type, including + // DIRECT_PC and NORMOP_PC. In these cases, preparation of the + // preconditioned source and reconstruction of the full solution are + // taken care of by Dirac::prepare() and Dirac::reconstruct(), + // respectively. + + getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); + + if (mat_solution && !direct_solve && !norm_error_solve) { // prepare source: b' = A^dag b + auto tmp = getFieldTmp(cvector_ref(in)); + blas::copy(tmp, in); + dirac.Mdag(in, tmp); + } else if (!mat_solution && direct_solve) { // perform the first of two solves: A^dag y = b + DiracMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + SolverParam solverParam(param); + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(out, in); + blas::copy(in, out); + delete solve; + solverParam.updateInvertParam(param); + } + + if (direct_solve) { + DiracM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + SolverParam solverParam(param); + + // chronological forecasting + if (param.chrono_use_resident && chronoResident[param.chrono_index].size() > 0) { + bool hermitian = false; + auto &mChrono = param.chrono_precision == param.cuda_prec ? m : mSloppy; + chronoExtrapolate(out[0], in[0], chronoResident[param.chrono_index], mChrono, hermitian); + } + + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(out, in); + delete solve; + solverParam.updateInvertParam(param); + } else if (!norm_error_solve) { + DiracMdagM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + SolverParam solverParam(param); + + // chronological forecasting + if (param.chrono_use_resident && chronoResident[param.chrono_index].size() > 0) { + bool hermitian = true; + auto &mChrono = param.chrono_precision == param.cuda_prec ? 
m : mSloppy; + chronoExtrapolate(out[0], in[0], chronoResident[param.chrono_index], mChrono, hermitian); + } + + // if using a Schwarz preconditioner with a normal operator then we must use the DiracMdagMLocal operator + if (param.inv_type_precondition != QUDA_INVALID_INVERTER && param.schwarz_type != QUDA_INVALID_SCHWARZ) { + DiracMdagMLocal mPreLocal(diracPre); + Solver *solve = Solver::create(solverParam, m, mSloppy, mPreLocal, mEig); + (*solve)(out, in); + delete solve; + solverParam.updateInvertParam(param); + } else { + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(out, in); + delete solve; + solverParam.updateInvertParam(param); + } + } else { // norm_error_solve + DiracMMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + auto tmp = getFieldTmp(cvector_ref(in)); + SolverParam solverParam(param); + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(tmp, in); // y = (M M^\dag) b + dirac.Mdag(out, tmp); // x = M^dag y + delete solve; + solverParam.updateInvertParam(param); + } + + if (getVerbosity() >= QUDA_VERBOSE) { + auto x_norm = blas::norm2(out); + for (auto i = 0u; i < x.size(); i++) printfQuda("Solution = %g\n", x_norm[i]); + } + + getProfile().TPSTART(QUDA_PROFILE_EPILOGUE); + if (param.chrono_make_resident) { + const int i = param.chrono_index; + if (i >= QUDA_MAX_CHRONO) errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); + + auto &basis = chronoResident[i]; + + if (param.chrono_max_dim < (int)basis.size()) { + errorQuda("Requested chrono_max_dim %i is smaller than already existing chronology %lu", param.chrono_max_dim, + basis.size()); + } + + if (not param.chrono_replace_last) { + // if we have not filled the space yet just augment + if ((int)basis.size() < param.chrono_max_dim) { + ColorSpinorParam cs_param(out[0]); + cs_param.setPrecision(param.chrono_precision); + basis.emplace_back(cs_param); + } + + // shuffle every entry down one and bring the last to the front + std::rotate(basis.begin(), basis.end() - 1, basis.end()); + } + basis[0] = out[0]; // set first entry to new solution + } + + dirac.reconstruct(x, b, param.solution_type); + + distanceReweight(x, param, false); + + if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION) { + // rescale the solution + for (auto bi : nb) bi = sqrt(bi); + blas::ax(nb, x); + } + + if (getVerbosity() >= QUDA_VERBOSE) { + auto x_norm = blas::norm2(x); + for (auto i = 0u; i < x.size(); i++) printfQuda("Reconstructed Solution = %g\n", x_norm[i]); + } + + if (param.compute_action) { + auto action = blas::cDotProduct(b, x); + param.action[0] = action[0].real(); + param.action[1] = action[0].imag(); + } + + getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); + } + + void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam ¶m, + const bool pc_solve); + + extern std::vector solutionResident; + + void solve(const std::vector &hp_x, const std::vector &hp_b, QudaInvertParam ¶m, + const GaugeField &u) + { + pushVerbosity(param.verbosity); + + if (hp_b.size() != hp_x.size()) + errorQuda("Number of solutions %lu != number of solves %lu", hp_x.size(), hp_b.size()); + int n_src = hp_b.size(); + + // It was probably a bad design decision to encode whether the system is even/odd preconditioned (PC) in + // solve_type and solution_type, rather than in separate members of QudaInvertParam. We're stuck with it + // for now, though, so here we factorize everything for convenience. 
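+ //
+ // Reference sketch of the factorization below (values follow directly
+ // from the enum checks):
+ //
+ //   solution_type                 solve_type             pc_solution  pc_solve
+ //   ----------------------------  ---------------------  -----------  --------
+ //   QUDA_MAT_SOLUTION             QUDA_DIRECT_SOLVE      false        false
+ //   QUDA_MATPC_SOLUTION           QUDA_DIRECT_PC_SOLVE   true         true
+ //   QUDA_MAT_SOLUTION             QUDA_NORMOP_PC_SOLVE   false        true
+ //   QUDA_MATPCDAG_MATPC_SOLUTION  QUDA_NORMERR_PC_SOLVE  true         true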
+ + bool pc_solution + = (param.solution_type == QUDA_MATPC_SOLUTION) || (param.solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); + bool pc_solve = (param.solve_type == QUDA_DIRECT_PC_SOLVE) || (param.solve_type == QUDA_NORMOP_PC_SOLVE) + || (param.solve_type == QUDA_NORMERR_PC_SOLVE); + + param.iter = 0; + + Dirac *dirac = nullptr; + Dirac *diracSloppy = nullptr; + Dirac *diracPre = nullptr; + Dirac *diracEig = nullptr; + + // Create the dirac operator and operators for sloppy, precondition, + // and an eigensolver + createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve); + + // wrap CPU host side pointers + ColorSpinorParam cpuParam(hp_b[0], param, u.X(), pc_solution, param.input_location); + std::vector h_b(n_src); + for (auto i = 0u; i < h_b.size(); i++) { + cpuParam.v = hp_b[i]; + h_b[i] = ColorSpinorField(cpuParam); + } + + std::vector h_x(n_src); + cpuParam.location = param.output_location; + for (auto i = 0u; i < h_x.size(); i++) { + cpuParam.v = hp_x[i]; + h_x[i] = ColorSpinorField(cpuParam); + } + + // download source + ColorSpinorParam cudaParam(cpuParam, param, QUDA_CUDA_FIELD_LOCATION); + cudaParam.create = QUDA_NULL_FIELD_CREATE; + std::vector b; + resize(b, n_src, cudaParam); + blas::copy(b, h_b); + + // now check if we need to invalidate the solutionResident vectors + std::vector x; + resize(x, n_src, cudaParam); + if (param.use_resident_solution == 1) { + for (auto &v : solutionResident) { + if (b[0].Precision() != v.Precision() || b[0].SiteSubset() != v.SiteSubset()) { + solutionResident.clear(); + break; + } + } + + if (!solutionResident.size()) { + cudaParam.create = QUDA_NULL_FIELD_CREATE; + solutionResident = std::vector(1, cudaParam); + } + x[0] = solutionResident[0].create_alias(cudaParam); + } else { + cudaParam.create = QUDA_NULL_FIELD_CREATE; + x[0] = ColorSpinorField(cudaParam); + } + + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES && !param.chrono_use_resident) { // download initial guess + // initial guess only supported for single-pass solvers + if ((param.solution_type == QUDA_MATDAG_MAT_SOLUTION || param.solution_type == QUDA_MATPCDAG_MATPC_SOLUTION) + && (param.solve_type == QUDA_DIRECT_SOLVE || param.solve_type == QUDA_DIRECT_PC_SOLVE)) { + errorQuda("Initial guess not supported for two-pass solver"); + } + + blas::copy(x, h_x); // solution + } else { // zero initial guess + blas::zero(x); + } + + solve(x, b, *dirac, *diracSloppy, *diracPre, *diracEig, param); + + if (!param.make_resident_solution) blas::copy(h_x, x); + + delete dirac; + delete diracSloppy; + delete diracPre; + delete diracEig; + + popVerbosity(); + } + +} // namespace quda diff --git a/lib/solver.cpp b/lib/solver.cpp index 357e30f067..2f9be27701 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -34,10 +34,13 @@ namespace quda { } } - void Solver::create(ColorSpinorField &x, const ColorSpinorField &b) + void Solver::create(cvector_ref &x, cvector_ref &b) { if (checkPrecision(x, b) != param.precision) errorQuda("Precision mismatch %d %d", checkPrecision(x, b), param.precision); + + param.true_res.resize(b.size()); + param.true_res_hq.resize(b.size()); } // solver factory @@ -78,11 +81,11 @@ namespace quda { break; case QUDA_CA_CGNE_INVERTER: report("CA-CGNE"); - solver = new CACGNE(mat, matSloppy, matPrecon, matEig, param); + solver = new CGNE(mat, matSloppy, matPrecon, matEig, param); break; case QUDA_CA_CGNR_INVERTER: report("CA-CGNR"); - solver = new CACGNR(mat, matSloppy, matPrecon, matEig, param); + solver = new CGNR(mat, matSloppy, matPrecon, matEig, 
param); break; case QUDA_CA_GCR_INVERTER: report("CA-GCR"); @@ -103,9 +106,9 @@ namespace quda { static_cast(param.preconditioner)->mg; // FIXME dirty hack to ensure that preconditioner precision set in interface isn't used in the outer GCR-MG solver if (!param.mg_instance) param.precision_precondition = param.precision_sloppy; - solver = new PreconCG(mat, *(mg), matSloppy, matPrecon, matEig, param); + solver = new PCG(mat, *(mg), matSloppy, matPrecon, matEig, param); } else { - solver = new PreconCG(mat, matSloppy, matPrecon, matEig, param); + solver = new PCG(mat, matSloppy, matPrecon, matEig, param); } break; case QUDA_BICGSTABL_INVERTER: @@ -145,11 +148,11 @@ namespace quda { break; case QUDA_CG3NE_INVERTER: report("CG3NE"); - solver = new CG3NE(mat, matSloppy, matPrecon, param); + solver = new CGNE(mat, matSloppy, matPrecon, matEig, param); break; case QUDA_CG3NR_INVERTER: report("CG3NR"); - solver = new CG3NR(mat, matSloppy, matPrecon, param); + solver = new CGNR(mat, matSloppy, matPrecon, matEig, param); break; default: errorQuda("Invalid solver type %d", param.inv_type); @@ -212,9 +215,8 @@ namespace quda { errorQuda("Unexpected preconditioned solver %d", outer.inv_type); } - // this sets a fixed iteration count if we're using the MR solver - inner.residual_type - = (outer.inv_type_precondition == QUDA_MR_INVERTER) ? QUDA_INVALID_RESIDUAL : QUDA_L2_RELATIVE_RESIDUAL; + // allows the inner solver to early exit if it converges quickly + inner.residual_type = QUDA_L2_RELATIVE_RESIDUAL; inner.iter = 0; inner.inv_type_precondition = QUDA_INVALID_INVERTER; @@ -368,108 +370,151 @@ namespace quda { { for (int i = 0; i < param.num_src; i++) { (*this)(out.Component(i), in.Component(i)); - param.true_res_offset[i] = param.true_res; - param.true_res_hq_offset[i] = param.true_res_hq; + param.true_res_offset[i] = static_cast(param.true_res); + param.true_res_hq_offset[i] = static_cast(param.true_res_hq); } } - double Solver::stopping(double tol, double b2, QudaResidualType residual_type) + vector Solver::stopping(double tol, cvector &b2, QudaResidualType residual_type) { - double stop=0.0; + vector stop(b2.size(), 0.0); if ( (residual_type & QUDA_L2_ABSOLUTE_RESIDUAL) && (residual_type & QUDA_L2_RELATIVE_RESIDUAL) ) { - // use the most stringent stopping condition - double lowest = (b2 < 1.0) ? b2 : 1.0; - stop = lowest*tol*tol; + for (auto i = 0u; i < b2.size(); i++) { + // use the most stringent stopping condition + double lowest = (b2[i] < 1.0) ? 
b2[i] : 1.0; + stop[i] = lowest * tol * tol; + } } else if (residual_type & QUDA_L2_ABSOLUTE_RESIDUAL) { - stop = tol*tol; + for (auto i = 0u; i < b2.size(); i++) stop[i] = tol * tol; + } else if (residual_type & QUDA_L2_RELATIVE_RESIDUAL) { + for (auto i = 0u; i < b2.size(); i++) stop[i] = b2[i] * tol * tol; } else { - stop = b2*tol*tol; + // if invalid residual then convergence is set by iteration count only + for (auto i = 0u; i < b2.size(); i++) stop[i] = 0.0; } return stop; } - bool Solver::convergence(double r2, double hq2, double r2_tol, double hq_tol) { - - // check the heavy quark residual norm if necessary - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - if (std::isnan(hq2) || std::isinf(hq2)) - errorQuda("Solver appears to have diverged with heavy quark residual %9.6e", hq2); - - if (hq2 > hq_tol) return false; - } + bool Solver::convergence(cvector &r2, cvector &hq2, cvector &r2_tol, cvector &hq_tol) + { + if (r2.size() != hq2.size() || r2.size() != r2_tol.size() || r2.size() != hq_tol.size()) + errorQuda("Mismatched vector lengths r2 = %lu hq2 = %lu r2_tol = %lu hq_tol = %lu", r2.size(), hq2.size(), + r2_tol.size(), hq_tol.size()); - // check the L2 relative residual norm if necessary - if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { - if (std::isnan(r2) || std::isinf(r2)) errorQuda("Solver appears to have diverged with residual %9.6e", r2); + for (auto i = 0u; i < r2.size(); i++) { + // check the heavy quark residual norm if necessary + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + if (std::isnan(hq2[i]) || std::isinf(hq2[i])) + errorQuda("Solver appears to have diverged with heavy quark residual %9.6e", hq2[i]); + if (hq2[i] > hq_tol[i]) return false; + } - if (r2 > r2_tol) return false; + // check the L2 relative residual norm if necessary + if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { + if (std::isnan(r2[i]) || std::isinf(r2[i])) + errorQuda("Solver appears to have diverged with residual %9.6e", r2[i]); + if (r2[i] > r2_tol[i]) return false; + } } - return true; } - bool Solver::convergenceHQ(double, double hq2, double, double hq_tol) + bool Solver::convergenceHQ(cvector &hq2, cvector &hq_tol) { - // check the heavy quark residual norm if necessary - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - if (std::isnan(hq2) || std::isinf(hq2)) - errorQuda("Solver appears to have diverged with heavy quark residual %9.6e", hq2); + if (hq2.size() != hq_tol.size()) + errorQuda("Mismatched vector lengths hq2 = %lu hq_tol = %lu", hq2.size(), hq_tol.size()); - if (hq2 > hq_tol) return false; + for (auto i = 0u; i < hq2.size(); i++) { + // check the heavy quark residual norm if necessary + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + if (std::isnan(hq2[i]) || std::isinf(hq2[i])) + errorQuda("Solver appears to have diverged with heavy quark residual %9.6e", hq2[i]); + if (hq2[i] > hq_tol[i]) return false; + } } - return true; } - bool Solver::convergenceL2(double r2, double, double r2_tol, double) + bool Solver::convergenceL2(cvector &r2, cvector &r2_tol) { - // check the L2 relative residual norm if necessary - if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { - if (std::isnan(r2) || std::isinf(r2)) errorQuda("Solver appears to have diverged with residual %9.6e", r2); + if (r2.size() != r2_tol.size()) + errorQuda("Mismatched vector lengths r2 = %lu r2_tol = %lu", 
r2.size(), r2_tol.size()); - if (r2 > r2_tol) return false; + for (auto i = 0u; i < r2.size(); i++) { + // check the L2 relative residual norm if necessary + if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { + if (std::isnan(r2[i]) || std::isinf(r2[i])) + errorQuda("Solver appears to have diverged with residual %9.6e", r2[i]); + if (r2[i] > r2_tol[i]) return false; + } } - return true; } - void Solver::PrintStats(const char* name, int k, double r2, double b2, double hq2) { - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - logQuda(QUDA_VERBOSE, "%s: %5d iterations, = %9.6e, |r|/|b| = %9.6e, heavy-quark residual = %9.6e\n", name, - k, r2, sqrt(r2 / b2), hq2); - } else { - logQuda(QUDA_VERBOSE, "%s: %5d iterations, = %9.6e, |r|/|b| = %9.6e\n", name, k, r2, sqrt(r2 / b2)); - } - - if (std::isnan(r2) || std::isinf(r2)) errorQuda("Solver appears to have diverged"); + std::string set_rhs_str(unsigned int i, size_t n) + { + std::string rhs_str; + if (n > 1) rhs_str += "n = " + std::to_string(i) + std::string(", "); + return rhs_str; } - void Solver::PrintSummary(const char *name, int k, double r2, double b2, - double r2_tol, double hq_tol) { - if (param.compute_true_res) { + void Solver::PrintStats(const char *name, int k, cvector &r2, cvector &b2, cvector &hq2_) + { + auto hq2 = hq2_.size() == 0 ? vector(r2.size(), 0.0) : hq2_; + if (r2.size() != b2.size() || r2.size() != hq2.size()) + errorQuda("Mismatched vector lengths r2 = %lu b2 = %lu hq2 = %lu", r2.size(), b2.size(), hq2.size()); + + for (auto i = 0u; i < r2.size(); i++) { + auto rhs_str = set_rhs_str(i, r2.size()); if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e, true = %9.6e " - "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", - name, k, sqrt(r2 / b2), param.true_res, sqrt(r2_tol / b2), param.true_res_hq, hq_tol); + logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e, heavy-quark residual = %9.6e\n", + name, k, rhs_str.c_str(), r2[i], sqrt(r2[i] / b2[i]), hq2[i]); } else { - logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e, true = %9.6e " - "(requested = %9.6e)\n", - name, k, sqrt(r2 / b2), param.true_res, sqrt(r2_tol / b2)); + logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e\n", name, k, rhs_str.c_str(), r2[i], + sqrt(r2[i] / b2[i])); } - } else { - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e " - "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", - name, k, sqrt(r2 / b2), sqrt(r2_tol / b2), param.true_res_hq, hq_tol); + + if (std::isnan(r2[i]) || std::isinf(r2[i])) errorQuda("Solver appears to have diverged for n = %d", i); + } + } + + void Solver::PrintSummary(const char *name, int k, cvector &r2, cvector &b2, cvector &r2_tol, + cvector &hq_tol_) + { + auto hq_tol = hq_tol_.size() == 0 ? 
vector(r2.size(), 0.0) : hq_tol_; + if (r2.size() != b2.size() || r2.size() != r2_tol.size() || r2.size() != hq_tol.size()) + errorQuda("Mismatched vector lengths r2 = %lu b2 = %lu r2_tol = %lu hq_tol = %lu", r2.size(), b2.size(), + r2_tol.size(), hq_tol.size()); + + for (auto i = 0u; i < r2.size(); i++) { + auto rhs_str = set_rhs_str(i, r2.size()); + if (param.compute_true_res) { + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e, true = %9.6e " + "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), param.true_res[i], sqrt(r2_tol[i] / b2[i]), + param.true_res_hq[i], hq_tol[i]); + } else { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e, true = %9.6e " + "(requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), param.true_res[i], sqrt(r2_tol[i] / b2[i])); + } } else { - logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e (requested = %9.6e)\n", name, - k, sqrt(r2 / b2), sqrt(r2_tol / b2)); + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e " + "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i]), param.true_res_hq[i], + hq_tol[i]); + } else { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e (requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i])); + } } } } @@ -525,4 +570,92 @@ namespace quda { } } + // check we're not solving on a zero-valued source + bool Solver::is_zero_src(cvector_ref &x, cvector_ref &b, cvector &b2) + { + // if computing null vectors then zero sources are fine + if (param.compute_null_vector != QUDA_COMPUTE_NULL_VECTOR_NO) return false; + + bool zero_src = true; + for (auto i = 0u; i < b.size(); i++) { + if (b2[i] == 0) { + warningQuda("source %d is zero", i); + x[i] = b[i]; + param.true_res[i] = 0.0; + param.true_res_hq[i] = 0.0; + } else { + zero_src = false; + } + } + return zero_src; + } + + void SolverParam::updateInvertParam(QudaInvertParam ¶m, int offset) + { + for (auto i = 0u; i < true_res.size(); i++) param.true_res[i] = true_res[i]; + for (auto i = 0u; i < true_res_hq.size(); i++) param.true_res_hq[i] = true_res_hq[i]; + param.iter += iter; + if (offset >= 0) { + param.true_res_offset[offset] = true_res_offset[offset]; + param.iter_res_offset[offset] = iter_res_offset[offset]; + param.true_res_hq_offset[offset] = true_res_hq_offset[offset]; + } else { + for (int i = 0; i < num_offset; i++) { + param.true_res_offset[i] = true_res_offset[i]; + param.iter_res_offset[i] = iter_res_offset[i]; + param.true_res_hq_offset[i] = true_res_hq_offset[i]; + } + } + // for incremental eigCG: + param.rhs_idx = rhs_idx; + + param.ca_lambda_min = ca_lambda_min; + param.ca_lambda_max = ca_lambda_max; + + param.ca_lambda_min_precondition = ca_lambda_min_precondition; + param.ca_lambda_max_precondition = ca_lambda_max_precondition; + + if (deflate) *static_cast(param.eig_param) = eig_param; + } + + void joinInvertParam(QudaInvertParam &out, const QudaInvertParam &in, const CommKey &split_key, int split_rank) + { + auto num_sub_partition = quda::product(split_key); + + int 
sub_partition_dims[] + = {comm_dim(0) / split_key[0], comm_dim(1) / split_key[1], comm_dim(2) / split_key[2], comm_dim(3) / split_key[3]}; + + int sub_partition_coords[] = {comm_coord(0) / sub_partition_dims[0], comm_coord(1) / sub_partition_dims[1], + comm_coord(2) / sub_partition_dims[2], comm_coord(3) / sub_partition_dims[3]}; + + auto j = sub_partition_coords[3]; + for (auto d = 2; d >= 0; d--) j = j * split_key[d] + sub_partition_coords[d]; + + std::vector true_res(in.num_src, 0.0); + std::vector true_res_hq(in.num_src, 0.0); + if (split_rank == 0) { // only rank 0 in each sub partition sets the residuals + for (auto i = 0; i < in.num_src_per_sub_partition; i++) { + true_res[i * num_sub_partition + j] = in.true_res[i]; + true_res_hq[i * num_sub_partition + j] = in.true_res_hq[i]; + } + } + + // communicate to all ranks + comm_allreduce_sum(true_res_hq); + comm_allreduce_sum(true_res); + memcpy(out.true_res, true_res.data(), true_res.size() * sizeof(double)); + memcpy(out.true_res_hq, true_res_hq.data(), true_res_hq.size() * sizeof(double)); + + out.iter = in.iter; + comm_allreduce_int(out.iter); + + out.ca_lambda_min = in.ca_lambda_min; + out.ca_lambda_max = in.ca_lambda_max; + out.ca_lambda_min_precondition = in.ca_lambda_min_precondition; + out.ca_lambda_max_precondition = in.ca_lambda_max_precondition; + + // now broadcast from global rank 0 to ensure uniformity + comm_broadcast(&out, sizeof(QudaInvertParam)); + } + } // namespace quda diff --git a/lib/solver.hpp b/lib/solver.hpp index ab93451aef..2482b8e6b3 100644 --- a/lib/solver.hpp +++ b/lib/solver.hpp @@ -51,8 +51,8 @@ namespace quda /** @brief Generate a Krylov space in a given basis @param[in] diracm Dirac matrix used to generate the Krylov space - @param[out] Ap dirac matrix times the Krylov basis vectors - @param[in,out] p Krylov basis vectors; assumes p[0] is in place + @param[out] Ap Dirac matrix times the Krylov basis vector sets + @param[in,out] p Krylov basis vector set; assumes p[0] is in place @param[in] n_krylov Size of krylov space @param[in] basis Basis type @param[in] m_map Slope mapping for Chebyshev basis; ignored for power basis @@ -60,9 +60,9 @@ namespace quda @param[in] args Parameter pack of ColorSpinorFields used as temporary passed to Dirac */ template - void Solver::computeCAKrylovSpace(const DiracMatrix &diracm, std::vector &Ap, - std::vector &p, int n_krylov, QudaCABasis basis, double m_map, - double b_map, Args &&...args) + void Solver::computeCAKrylovSpace(const DiracMatrix &diracm, std::vector> &Ap, + std::vector> &p, int n_krylov, QudaCABasis basis, + double m_map, double b_map, Args &&...args) { // in some cases p or Ap may be larger if (static_cast(p.size()) < n_krylov) errorQuda("Invalid p.size() %lu < n_krylov %d", p.size(), n_krylov); diff --git a/lib/targets/cuda/target_cuda.cmake b/lib/targets/cuda/target_cuda.cmake index 0f9484b1b8..eb71bb9d53 100644 --- a/lib/targets/cuda/target_cuda.cmake +++ b/lib/targets/cuda/target_cuda.cmake @@ -82,19 +82,6 @@ endif() # CUDA specific QUDA options include(CMakeDependentOption) -# large arg support requires CUDA 12.1 -cmake_dependent_option(QUDA_LARGE_KERNEL_ARG "enable large kernel arg support" ON "${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.1" OFF ) -message(STATUS "Large kernel arguments supported: ${QUDA_LARGE_KERNEL_ARG}") -mark_as_advanced(QUDA_LARGE_KERNEL_ARG) - -# Set the maximum multi-RHS per kernel -if(QUDA_LARGE_KERNEL_ARG) - set(QUDA_MAX_MULTI_RHS "64" CACHE STRING "maximum number of simultaneous RHS in a kernel") -else() - 
set(QUDA_MAX_MULTI_RHS "16" CACHE STRING "maximum number of simultaneous RHS in a kernel") -endif() -message(STATUS "Max number of rhs per kernel: ${QUDA_MAX_MULTI_RHS}") - option(QUDA_VERBOSE_BUILD "display kernel register usage" OFF) option(QUDA_JITIFY "build QUDA using Jitify" OFF) option(QUDA_DOWNLOAD_NVSHMEM "Download NVSHMEM" OFF) @@ -139,6 +126,21 @@ endif() set_target_properties(quda PROPERTIES CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES}) +# large arg support requires CUDA 12.1 and Volta+ +cmake_dependent_option(QUDA_LARGE_KERNEL_ARG "enable large kernel arg support" ON + "${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.1 AND ${QUDA_COMPUTE_CAPABILITY} GREATER_EQUAL 70" + OFF) +message(STATUS "Large kernel arguments supported: ${QUDA_LARGE_KERNEL_ARG}") +mark_as_advanced(QUDA_LARGE_KERNEL_ARG) + +# Set the maximum multi-RHS per kernel +if(QUDA_LARGE_KERNEL_ARG) + set(QUDA_MAX_MULTI_RHS "64" CACHE STRING "maximum number of simultaneous RHS in a kernel") +else() + set(QUDA_MAX_MULTI_RHS "16" CACHE STRING "maximum number of simultaneous RHS in a kernel") +endif() +message(STATUS "Max number of rhs per kernel: ${QUDA_MAX_MULTI_RHS}") + # QUDA_HASH for tunecache set(HASH cpu_arch=${CPU_ARCH},gpu_arch=${QUDA_GPU_ARCH},cuda_version=${CMAKE_CUDA_COMPILER_VERSION}) set(GITVERSION "${PROJECT_VERSION}-${GITVERSION}-${QUDA_GPU_ARCH}") diff --git a/lib/timer.cpp b/lib/timer.cpp index f8ed6a17c7..154467ea8f 100644 --- a/lib/timer.cpp +++ b/lib/timer.cpp @@ -1,6 +1,7 @@ #include #include #include +#include "monitor.h" #ifdef INTERFACE_NVTX #include "nvtx3/nvToolsExt.h" @@ -240,15 +241,44 @@ namespace quda { static std::stack tp_stack; - pushProfile::pushProfile(TimeProfile &profile, double &secs, double &gflops) : - profile(profile), secs(secs), gflops(gflops), flops(Tunable::flops_global()) + static double double_dummy; + + pushProfile::pushProfile(TimeProfile &profile, QudaInvertParam *param) : + profile(profile), + secs(param ? param->secs : double_dummy), + gflops(param ? param->gflops : double_dummy), + energy(param ? param->energy : double_dummy), + power(param ? param->power : double_dummy), + temp(param ? param->temp : double_dummy), + clock(param ? param->clock : double_dummy), + flops(Tunable::flops_global()) + { + if (profile.Name() != getProfile().Name()) { + // only push to stack if this profile not already the active one + profile.TPSTART(QUDA_PROFILE_TOTAL); + tp_stack.push(&profile); + active = true; + monitor_start = monitor::size(); + } + } + pushProfile::pushProfile(TimeProfile &profile, QudaQuarkSmearParam *param) : + profile(profile), + secs(param ? param->secs : double_dummy), + gflops(param ? param->gflops : double_dummy), + energy(param ? param->energy : double_dummy), + power(param ? param->power : double_dummy), + temp(param ? param->temp : double_dummy), + clock(param ? 
param->clock : double_dummy), + flops(Tunable::flops_global()) { if (profile.Name() != getProfile().Name()) { // only push to stack if this profile not already the active one profile.TPSTART(QUDA_PROFILE_TOTAL); tp_stack.push(&profile); active = true; + comm_barrier(); + monitor_start = monitor::size(); } } @@ -260,9 +290,37 @@ namespace quda { if (&(this->profile) != &profile) errorQuda("Popped profile is not the expected one"); tp_stack.pop(); profile.TPSTOP(QUDA_PROFILE_TOTAL); + secs = profile.Last(QUDA_PROFILE_TOTAL); + comm_allreduce_max(secs); + gflops = (Tunable::flops_global() - flops) * 1e-9; - if (&gflops != &gflops_dummy) comm_allreduce_sum(gflops); + if (&gflops != &double_dummy) comm_allreduce_sum(gflops); + + // make sure all processes will start + std::vector monitor_start_global = {static_cast(monitor_start)}; + comm_allreduce_min(monitor_start_global); + if (monitor_start_global[0] > 0) { + monitor_end = monitor::size(); + auto mean_state = monitor::mean(monitor_start, monitor_end); + energy = mean_state.energy; + comm_allreduce_sum(energy); + + power = mean_state.power; + comm_allreduce_sum(power); + power /= comm_size(); + + temp = mean_state.temp; + comm_allreduce_sum(temp); + temp /= comm_size(); + + clock = mean_state.clock; + comm_allreduce_sum(clock); + clock /= comm_size(); + } + + // cache is written out even if a long benchmarking job gets interrupted + saveTuneCache(); } } diff --git a/lib/transfer.cpp b/lib/transfer.cpp index ae106b92ac..b3925ccc4c 100644 --- a/lib/transfer.cpp +++ b/lib/transfer.cpp @@ -339,6 +339,7 @@ namespace quda { // apply the prolongator void Transfer::P(cvector_ref &out, cvector_ref &in) const { getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + if (out.size() != in.size()) errorQuda("Mismatched set sizes %lu != %lu", out.size(), in.size()); initializeLazy(use_gpu ? QUDA_CUDA_FIELD_LOCATION : QUDA_CPU_FIELD_LOCATION); const int *fine_to_coarse = use_gpu ? fine_to_coarse_d : fine_to_coarse_h; @@ -413,6 +414,7 @@ namespace quda { void Transfer::R(cvector_ref &out, cvector_ref &in) const { getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + if (out.size() != in.size()) errorQuda("Mismatched set sizes %lu != %lu", out.size(), in.size()); initializeLazy(use_gpu ? QUDA_CUDA_FIELD_LOCATION : QUDA_CPU_FIELD_LOCATION); const int *fine_to_coarse = use_gpu ? 
fine_to_coarse_d : fine_to_coarse_h; diff --git a/lib/tune.cpp b/lib/tune.cpp index cdfa24660b..3bcd7da964 100644 --- a/lib/tune.cpp +++ b/lib/tune.cpp @@ -105,7 +105,7 @@ namespace quda const std::string get_resource_path() { - static std::string resource_path; + static std::string resource_path = {}; static bool init = false; if (!init) { @@ -114,10 +114,8 @@ namespace quda if (!path) { warningQuda("Environment variable QUDA_RESOURCE_PATH is not set."); - return {}; } else if (stat(path, &pstat) || !S_ISDIR(pstat.st_mode)) { warningQuda("The path \"%s\" specified by QUDA_RESOURCE_PATH does not exist or is not a directory.", path); - return {}; } else { resource_path = path; } @@ -440,7 +438,11 @@ namespace quda auto &resource_path = get_resource_path(); if (resource_path.empty()) { - warningQuda("Caching of tuned parameters will be disabled"); + static bool init = false; + if (!init) { + warningQuda("Caching of tuned parameters will be disabled"); + init = true; + } return; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a26599564f..b295e00d9c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1004,7 +1004,7 @@ endif() if(QUDA_DIRAC_WILSON) add_test(NAME invert_test_wilson COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} - --dslash-type wilson --ngcrkrylov 8 + --dslash-type wilson --ngcrkrylov 8 --nsrc 4 --nsrc-tile 4 --dim 2 4 6 8 --niter 1000 --enable-testing true --gtest_output=xml:invert_test_wilson.xml) @@ -1032,7 +1032,7 @@ if(QUDA_DIRAC_TWISTED_MASS) --matpc even-even --enable-testing true --gtest_output=xml:invert_test_twisted_mass_sym.xml) - add_test(NAME invert_test_twisted_mass_asym} + add_test(NAME invert_test_twisted_mass_asym COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} --dslash-type twisted-mass --dim 2 4 6 8 --niter 1000 --ngcrkrylov 8 @@ -1204,7 +1204,7 @@ foreach(prec IN LISTS TEST_PRECS) COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} --dslash-type asqtad --ngcrkrylov 8 --compute-fat-long true --dim 6 6 6 8 --prec ${prec} --tol ${tol} --tolhq ${tol} --niter 1000 - --enable-testing true + --enable-testing true --nsrc 4 --nsrc-tile 4 --gtest_output=xml:invert_test_asqtad_${prec}.xml) if(DEFINED ENV{QUDA_ENABLE_TUNING}) diff --git a/tests/asan.h b/tests/asan.h new file mode 100644 index 0000000000..1c18092210 --- /dev/null +++ b/tests/asan.h @@ -0,0 +1,13 @@ +#pragma once + +extern "C" { + +/** + @brief Set the default ASAN options. This ensures that QUDA just + works when SANITIZE is enabled without requiring ASAN_OPTIONS to + be set. We default disable leak checking, otherwise this will + cause ctest to fail with MPI library leaks. This declaration + cannot be in the test library, and must be in the test executable. 
+*/ +const char *__asan_default_options() { return "detect_leaks=0,protect_shadow_gap=0"; } +} diff --git a/tests/blas_test.cpp b/tests/blas_test.cpp index 3a62f4ac9e..0bb6980c54 100644 --- a/tests/blas_test.cpp +++ b/tests/blas_test.cpp @@ -17,6 +17,8 @@ // google test #include +#include "test.h" + using namespace quda; /** diff --git a/tests/clover_force_test.cpp b/tests/clover_force_test.cpp index 7488b0856d..4a9edb57b2 100644 --- a/tests/clover_force_test.cpp +++ b/tests/clover_force_test.cpp @@ -4,6 +4,7 @@ #include "clover_force_reference.h" #include "misc.h" +#include "test.h" #include // convenient quark field container #include #include diff --git a/tests/contract_ft_test.cpp b/tests/contract_ft_test.cpp index 552a0527db..cc793ede5b 100644 --- a/tests/contract_ft_test.cpp +++ b/tests/contract_ft_test.cpp @@ -10,6 +10,7 @@ #include #include #include "misc.h" +#include "test.h" // google test #include @@ -113,7 +114,6 @@ inline void fill_buffers(std::array, N> &buffs, const std::ar srand(l); for (int i = 0; i < dofs; i++) { -#pragma unroll for (int n = 0; n < N; n++) { buffs[n][ll * dofs + i] = 2. * (rand() / (Float)RAND_MAX) - 1.; } } } diff --git a/tests/deflated_invert_test.cpp b/tests/deflated_invert_test.cpp index f6ca57e043..e0f74be86c 100644 --- a/tests/deflated_invert_test.cpp +++ b/tests/deflated_invert_test.cpp @@ -277,7 +277,7 @@ int main(int argc, char **argv) double l2r = sqrt(nrm2 / src2); printfQuda("Residuals: (L2 relative) tol %g, QUDA = %g, host = %g; (heavy-quark) tol %g, QUDA = %g\n", inv_param.tol, - inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq); + inv_param.true_res[0], l2r, inv_param.tol_hq, inv_param.true_res_hq[0]); freeGaugeQuda(); if (dslash_type == QUDA_CLOVER_WILSON_DSLASH || dslash_type == QUDA_TWISTED_CLOVER_DSLASH) freeCloverQuda(); diff --git a/tests/dslash_test_utils.h b/tests/dslash_test_utils.h index 914edb247d..01041fcf26 100644 --- a/tests/dslash_test_utils.h +++ b/tests/dslash_test_utils.h @@ -26,6 +26,7 @@ #include #include +#include "test.h" using namespace quda; @@ -319,8 +320,10 @@ struct DslashTestWrapper { dirac = Dirac::create(diracParam); } else { - double cpu_norm = blas::norm2(spinor); - printfQuda("Source: CPU = %e\n", cpu_norm); + for (int i = 0; i < Nsrc; i++) { + double cpu_norm = blas::norm2(spinor[i]); + printfQuda("Source %d: CPU = %e\n", i, cpu_norm); + } } } diff --git a/tests/eigensolve_test.cpp b/tests/eigensolve_test.cpp index 5b063755cc..e325511e05 100644 --- a/tests/eigensolve_test.cpp +++ b/tests/eigensolve_test.cpp @@ -14,6 +14,7 @@ #include #include #include +#include "test.h" // Place params above "eigensolve_test_gtest.hpp" so they // are visible therein. 
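The new tests/asan.h above relies on AddressSanitizer's documented __asan_default_options hook. A standalone sketch of how the hook behaves when linked into an executable (the build line is illustrative):

```cpp
// Build with: clang++ -fsanitize=address asan_demo.cpp
// ASan calls __asan_default_options() at startup to obtain default flags;
// anything set in the ASAN_OPTIONS environment variable still overrides it.
extern "C" const char *__asan_default_options() { return "detect_leaks=0,protect_shadow_gap=0"; }

int main()
{
  return 0; // leaks would normally be reported at exit; detect_leaks=0 suppresses that
}
```

As the header comment notes, the definition must live in the test executable rather than the test library for the sanitizer runtime to pick it up.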
diff --git a/tests/gauge_path_test.cpp b/tests/gauge_path_test.cpp index fee31cb53c..1efa49ef97 100644 --- a/tests/gauge_path_test.cpp +++ b/tests/gauge_path_test.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "test.h" static QudaGaugeFieldOrder gauge_order = QUDA_QDP_GAUGE_ORDER; diff --git a/tests/hisq_paths_force_test.cpp b/tests/hisq_paths_force_test.cpp index 7560dbc105..d696ca1558 100644 --- a/tests/hisq_paths_force_test.cpp +++ b/tests/hisq_paths_force_test.cpp @@ -7,6 +7,7 @@ #include #include "gauge_field.h" #include "misc.h" +#include "test.h" #include "hisq_force_reference.h" #include "ks_improved_force.h" #include "momentum.h" diff --git a/tests/hisq_stencil_ctest.cpp b/tests/hisq_stencil_ctest.cpp index 6df6d1d977..811032ded2 100644 --- a/tests/hisq_stencil_ctest.cpp +++ b/tests/hisq_stencil_ctest.cpp @@ -1,3 +1,4 @@ +#include "test.h" #include "hisq_stencil_test_utils.h" using namespace quda; diff --git a/tests/hisq_stencil_test.cpp b/tests/hisq_stencil_test.cpp index 3b20287d3b..cc6581733d 100644 --- a/tests/hisq_stencil_test.cpp +++ b/tests/hisq_stencil_test.cpp @@ -1,3 +1,4 @@ +#include "test.h" #include "hisq_stencil_test_utils.h" using namespace quda; diff --git a/tests/hisq_unitarize_force_test.cpp b/tests/hisq_unitarize_force_test.cpp index 01e3c78c18..04e1576f20 100644 --- a/tests/hisq_unitarize_force_test.cpp +++ b/tests/hisq_unitarize_force_test.cpp @@ -7,6 +7,7 @@ #include #include "gauge_field.h" #include "misc.h" +#include "test.h" #include "hisq_force_reference.h" #include "ks_improved_force.h" #include diff --git a/tests/host_reference/dslash_reference.cpp b/tests/host_reference/dslash_reference.cpp index 8cf63d2e27..5506514f85 100644 --- a/tests/host_reference/dslash_reference.cpp +++ b/tests/host_reference/dslash_reference.cpp @@ -12,26 +12,27 @@ // Overload for workflows without multishift std::array verifyInversion(void *spinorOut, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv) + QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv, + int src_idx) { void **spinorOutMulti = nullptr; return verifyInversion(spinorOut, spinorOutMulti, spinorIn, spinorCheck, gauge_param, inv_param, gauge, clover, - clover_inv); + clover_inv, src_idx); } std::array verifyInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, - void *clover, void *clover_inv) + void *clover, void *clover_inv, int src_idx) { std::array res = {std::numeric_limits::max(), std::numeric_limits::max()}; if (dslash_type == QUDA_DOMAIN_WALL_DSLASH || dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH || dslash_type == QUDA_MOBIUS_DWF_DSLASH || dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) { res = verifyDomainWallTypeInversion(spinorOut, spinorOutMulti, spinorIn, spinorCheck, gauge_param, inv_param, gauge, - clover, clover_inv); + clover, clover_inv, src_idx); } else if (dslash_type == QUDA_WILSON_DSLASH || dslash_type == QUDA_CLOVER_WILSON_DSLASH || dslash_type == QUDA_TWISTED_MASS_DSLASH || dslash_type == QUDA_TWISTED_CLOVER_DSLASH) { res = verifyWilsonTypeInversion(spinorOut, spinorOutMulti, spinorIn, spinorCheck, gauge_param, inv_param, gauge, - clover, clover_inv); + clover, clover_inv, src_idx); } else { errorQuda("Unsupported dslash_type=%s", get_dslash_str(dslash_type)); } @@ -40,7 +41,7 @@ std::array verifyInversion(void *spinorOut, void **spinorOutMulti, vo std::array 
verifyDomainWallTypeInversion(void *spinorOut, void **, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, - void **gauge, void *, void *) + void **gauge, void *, void *, int src_idx) { if (multishift > 1) errorQuda("Multishift not supported"); @@ -163,15 +164,15 @@ std::array verifyDomainWallTypeInversion(void *spinorOut, void **, vo double l2r = sqrt(nrm2 / src2); printfQuda("Residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, QUDA = %9.6e\n", - inv_param.tol, inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq); + inv_param.tol, inv_param.true_res[src_idx], l2r, inv_param.tol_hq, inv_param.true_res_hq[src_idx]); return {l2r, inv_param.tol_hq}; ; } -std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, - void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv) +std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, + QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, + void *clover, void *clover_inv, int src_idx) { int vol = (inv_param.solution_type == QUDA_MAT_SOLUTION || inv_param.solution_type == QUDA_MATDAG_MAT_SOLUTION ? V : Vh); @@ -409,7 +410,7 @@ std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOu printfQuda( "Residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, QUDA = %9.6e\n", - inv_param.tol, inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq); + inv_param.tol, inv_param.true_res[src_idx], l2r, inv_param.tol_hq, inv_param.true_res_hq[src_idx]); } return {l2r_max, inv_param.tol_hq}; @@ -745,17 +746,17 @@ double verifyWilsonTypeSingularVector(void *spinor_left, void *spinor_right, dou std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param) + QudaInvertParam &inv_param, int src_idx) { std::vector out_vector(1); out_vector[0] = out; - return verifyStaggeredInversion(in, out_vector, fat_link, long_link, inv_param); + return verifyStaggeredInversion(in, out_vector, fat_link, long_link, inv_param, src_idx); } std::array verifyStaggeredInversion(quda::ColorSpinorField &in, std::vector &out_vector, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param) + QudaInvertParam &inv_param, int src_idx) { int dagger = inv_param.dagger == QUDA_DAG_YES ? 
1 : 0; double l2r_max = 0.0; @@ -834,7 +835,7 @@ std::array verifyStaggeredInversion(quda::ColorSpinorField &in, printfQuda("Residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, QUDA = %9.6e, " "host = %9.6e\n", - inv_param.tol, inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq, hqr); + inv_param.tol, inv_param.true_res[src_idx], l2r, inv_param.tol_hq, inv_param.true_res_hq[src_idx], hqr); l2r_max = l2r; hqr_max = hqr; diff --git a/tests/host_reference/dslash_reference.h b/tests/host_reference/dslash_reference.h index cf7f176c4c..83b8e5934e 100644 --- a/tests/host_reference/dslash_reference.h +++ b/tests/host_reference/dslash_reference.h @@ -87,16 +87,17 @@ static inline void su3Tmul(sFloat *res, const gFloat *mat, const sFloat *vec) } std::array verifyInversion(void *spinorOut, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv); + QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv, + int src_idx); std::array verifyInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, - void *clover, void *clover_inv); + void *clover, void *clover_inv, int src_idx = 0); std::array verifyDomainWallTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, void *clover, - void *clover_inv); + void *clover_inv, int src_idx); double verifyWilsonTypeEigenvector(void *spinor, double _Complex lambda, int i, QudaGaugeParam &gauge_param, QudaEigParam &eig_param, void **gauge, void *clover, void *clover_inv); @@ -105,9 +106,9 @@ double verifyWilsonTypeSingularVector(void *spinor_left, void *spinor_right, dou QudaGaugeParam &gauge_param, QudaEigParam &eig_param, void **gauge, void *clover, void *clover_inv); -std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, - void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv); +std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, + QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, + void *clover, void *clover_inv, int src_idx); /** * @brief Verify a staggered inversion on the host. 
This version is a thin wrapper around a version that takes @@ -122,7 +123,7 @@ std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOu */ std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param); + QudaInvertParam &inv_param, int src_idx); /** * @brief Verify a single- or multi-shift staggered inversion on the host @@ -137,7 +138,7 @@ std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda: std::array verifyStaggeredInversion(quda::ColorSpinorField &in, std::vector &out_vector, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param); + QudaInvertParam &inv_param, int src_idx = 0); /** * @brief Verify a staggered-type eigenvector diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index bf56dc2162..2c945cc77c 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -13,6 +13,7 @@ #include #include #include +#include "test.h" QudaGaugeParam gauge_param; QudaInvertParam inv_param; @@ -21,6 +22,7 @@ QudaInvertParam mg_inv_param; QudaEigParam mg_eig_param[QUDA_MAX_MG_LEVEL]; QudaEigParam eig_param; bool use_split_grid = false; +bool use_multi_src = false; std::vector gauge_; std::array gauge; @@ -47,6 +49,7 @@ void display_test_info() printfQuda(" - number of levels %d\n", mg_levels); for (int i = 0; i < mg_levels - 1; i++) { printfQuda(" - level %d number of null-space vectors %d\n", i + 1, nvec[i]); + printfQuda(" - level %d null-space vector batch size %d\n", i + 1, nvec_batch[i]); printfQuda(" - level %d number of pre-smoother applications %d\n", i + 1, nu_pre[i]); printfQuda(" - level %d number of post-smoother applications %d\n", i + 1, nu_post[i]); } @@ -225,6 +228,7 @@ std::vector> solve(test_t param) for (int i = 0; i < 4; i++) inv_param.split_grid[i] = grid_partition[i]; int num_sub_partition = grid_partition[0] * grid_partition[1] * grid_partition[2] * grid_partition[3]; use_split_grid = num_sub_partition > 1; + use_multi_src = use_split_grid || (Nsrc_tile > 1); // Now QUDA is initialised and the fields are loaded, we may setup the preconditioner void *mg_preconditioner = nullptr; @@ -233,11 +237,18 @@ std::vector> solve(test_t param) mg_preconditioner = newMultigridQuda(&mg_param); inv_param.preconditioner = mg_preconditioner; - printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.secs, mg_param.gflops / mg_param.secs); + printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, + mg_param.invert_param->gflops / mg_param.invert_param->secs); + if (mg_param.invert_param->energy > 0) { + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", mg_param.invert_param->energy, + mg_param.invert_param->power, mg_param.invert_param->temp, mg_param.invert_param->clock); + } } // Vector construct START //----------------------------------------------------------------------------------- + if (Nsrc > QUDA_MAX_MULTI_SRC) + errorQuda("Nsrc = %d which is greater than QUDA_MAX_MULTI_SRC = %d\n", Nsrc, QUDA_MAX_MULTI_SRC); std::vector in(Nsrc); std::vector out(Nsrc); std::vector out_multishift(multishift * Nsrc); @@ -304,7 +315,7 @@ std::vector> solve(test_t param) verifySpinorDistanceReweight(in[0], distance_pc_alpha0, distance_pc_t0); } - if (!use_split_grid) { + if (!use_multi_src || multishift > 1) { for (int i = 0; i < Nsrc; i++) { // If deflating, preserve the deflation space between solves @@ -316,34 +327,58 @@ std::vector> solve(test_t param)
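The hunk below replaces the single-source loop with a tiled multi-source driver: sources are solved in batches of Nsrc_tile, with the host pointer arrays re-aimed at each tile before the batched solve. A minimal sketch of the tiling pattern (invert_multi_src_stub is a hypothetical stand-in for the invertMultiSrcQuda call):

```cpp
#include <cstdio>
#include <vector>

// hypothetical stand-in for invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param)
void invert_multi_src_stub(void **, void **, int n) { printf("batched solve of %d sources\n", n); }

int main()
{
  const int Nsrc = 8, Nsrc_tile = 4; // must satisfy Nsrc % Nsrc_tile == 0
  std::vector<std::vector<double>> in(Nsrc, std::vector<double>(16, 1.0)), out = in;
  std::vector<void *> hp_x(Nsrc_tile), hp_b(Nsrc_tile);

  for (int j = 0; j < Nsrc; j += Nsrc_tile) { // one batched solve per tile
    for (int i = 0; i < Nsrc_tile; i++) {
      hp_x[i] = out[j + i].data(); // solutions for this tile
      hp_b[i] = in[j + i].data();  // sources for this tile
    }
    invert_multi_src_stub(hp_x.data(), hp_b.data(), Nsrc_tile);
  }
}
```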
invertQuda(out[i].data(), in[i].data(), &inv_param); } + // move residuals to i^th location for verification after solves have finished + inv_param.true_res[i] = inv_param.true_res[0]; + inv_param.true_res_hq[i] = inv_param.true_res_hq[0]; + time[i] = inv_param.secs; gflops[i] = inv_param.gflops / inv_param.secs; iter[i] = inv_param.iter; printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); + if (inv_param.energy > 0) { + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", inv_param.energy, + inv_param.power, inv_param.temp, inv_param.clock); + } } + } else { - inv_param.num_src = Nsrc; - inv_param.num_src_per_sub_partition = Nsrc / num_sub_partition; + inv_param.num_src = Nsrc_tile; + inv_param.num_src_per_sub_partition = Nsrc_tile / num_sub_partition; // Host arrays for solutions, sources, and check - std::vector _hp_x(Nsrc); - std::vector _hp_b(Nsrc); - for (int i = 0; i < Nsrc; i++) { - _hp_x[i] = out[i].data(); - _hp_b[i] = in[i].data(); + std::vector _hp_x(Nsrc_tile); + std::vector _hp_b(Nsrc_tile); + + for (int j = 0; j < Nsrc; j += Nsrc_tile) { + // If deflating, preserve the deflation space between solves + if (inv_deflate) eig_param.preserve_deflation = j < Nsrc - Nsrc_tile ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; + + for (int i = 0; i < Nsrc_tile; i++) { + _hp_x[i] = out[j + i].data(); + _hp_b[i] = in[j + i].data(); + } + + invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); + + // move residuals to (i+j)^th location for verification after solves have finished + for (int i = 0; i < Nsrc_tile; i++) { + inv_param.true_res[j + i] = inv_param.true_res[i]; + inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i]; + } + + quda::comm_allreduce_int(inv_param.iter); + inv_param.iter /= quda::comm_size() / num_sub_partition; + quda::comm_allreduce_sum(inv_param.gflops); + inv_param.gflops /= quda::comm_size() / num_sub_partition; + quda::comm_allreduce_max(inv_param.secs); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, + inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); + if (inv_param.energy > 0) { + printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n", + inv_param.energy, inv_param.energy / Nsrc_tile, inv_param.power, inv_param.temp, inv_param.clock); + } } - - // Run split grid - invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); - - quda::comm_allreduce_int(inv_param.iter); - inv_param.iter /= quda::comm_size() / num_sub_partition; - quda::comm_allreduce_sum(inv_param.gflops); - inv_param.gflops /= quda::comm_size() / num_sub_partition; - quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n", num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs); } // QUDA invert test COMPLETE @@ -353,14 +388,14 @@ std::vector> solve(test_t param) if (inv_multigrid) destroyMultigridQuda(mg_preconditioner); // Compute performance statistics - if (Nsrc > 1 && !use_split_grid) performanceStats(time, gflops, iter); + if (!use_multi_src) performanceStats(time, gflops, iter); std::vector> res(Nsrc); // Perform host side verification of inversion if requested if (verify_results) { for (int i = 0; i < Nsrc; i++) { res[i] = verifyInversion(out[i].data(), _hp_multi_x[i].data(), in[i].data(), check.data(), gauge_param, inv_param, - gauge.data(),
clover.data(), clover_inv.data()); + gauge.data(), clover.data(), clover_inv.data(), i); } } return res; diff --git a/tests/multigrid_evolve_test.cpp b/tests/multigrid_evolve_test.cpp index c77af2ed92..1ac7107131 100644 --- a/tests/multigrid_evolve_test.cpp +++ b/tests/multigrid_evolve_test.cpp @@ -63,6 +63,7 @@ void display_test_info() printfQuda(" - level %d number of null-space vectors %d\n", i + 1, nvec[i]); printfQuda(" - level %d number of pre-smoother applications %d\n", i + 1, nu_pre[i]); printfQuda(" - level %d number of post-smoother applications %d\n", i + 1, nu_post[i]); + printfQuda(" - level %d null-space vector batch size %d\n", i + 1, nvec_batch[i]); } printfQuda("Outer solver paramers\n"); diff --git a/tests/staggered_dslash_test_utils.h b/tests/staggered_dslash_test_utils.h index 4d336d8d4f..a9b499d71a 100644 --- a/tests/staggered_dslash_test_utils.h +++ b/tests/staggered_dslash_test_utils.h @@ -20,6 +20,7 @@ #include #include #include +#include "test.h" using namespace quda; diff --git a/tests/staggered_eigensolve_test.cpp b/tests/staggered_eigensolve_test.cpp index 335367b478..93ae358896 100644 --- a/tests/staggered_eigensolve_test.cpp +++ b/tests/staggered_eigensolve_test.cpp @@ -14,6 +14,7 @@ #include #include #include +#include "test.h" QudaGaugeParam gauge_param; QudaInvertParam eig_inv_param; diff --git a/tests/staggered_gsmear_test.cpp b/tests/staggered_gsmear_test.cpp index e5ce29c346..84f793dcb9 100644 --- a/tests/staggered_gsmear_test.cpp +++ b/tests/staggered_gsmear_test.cpp @@ -1,3 +1,4 @@ +#include "test.h" #include "staggered_gsmear_test_utils.h" using namespace quda; diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 9a00b3d875..a098324ef6 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -17,6 +17,7 @@ #include #include #include +#include "test.h" QudaGaugeParam gauge_param; QudaInvertParam inv_param; @@ -25,6 +26,7 @@ QudaInvertParam mg_inv_param; QudaEigParam mg_eig_param[QUDA_MAX_MG_LEVEL]; QudaEigParam eig_param; bool use_split_grid = false; +bool use_multi_src = false; // print instructions on how to run the old tests bool print_legacy_info = false; @@ -48,6 +50,7 @@ void display_test_info() printfQuda(" - number of levels %d\n", mg_levels); for (int i = 0; i < mg_levels - 1; i++) { printfQuda(" - level %d number of null-space vectors %d\n", i + 1, nvec[i]); + printfQuda(" - level %d null-space vector batch size %d\n", i + 1, nvec_batch[i]); printfQuda(" - level %d number of pre-smoother applications %d\n", i + 1, nu_pre[i]); printfQuda(" - level %d number of post-smoother applications %d\n", i + 1, nu_post[i]); } @@ -271,6 +274,7 @@ std::vector> solve(test_t param) for (int i = 0; i < 4; i++) inv_param.split_grid[i] = grid_partition[i]; int num_sub_partition = grid_partition[0] * grid_partition[1] * grid_partition[2] * grid_partition[3]; use_split_grid = num_sub_partition > 1; + use_multi_src = use_split_grid || (Nsrc_tile > 1); // Setup the multigrid preconditioner void *mg_preconditioner = nullptr; @@ -279,11 +283,18 @@ std::vector> solve(test_t param) mg_preconditioner = newMultigridQuda(&mg_param); inv_param.preconditioner = mg_preconditioner; - printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.secs, mg_param.gflops / mg_param.secs); + printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, + mg_param.invert_param->gflops / mg_param.invert_param->secs); + if (mg_param.invert_param->energy > 0) { + printfQuda("Energy = %g J, Mean power = %g W, 
mean temp = %g C, mean clock = %f\n", mg_param.invert_param->energy, + mg_param.invert_param->power, mg_param.invert_param->temp, mg_param.invert_param->clock); + } } // Staggered vector construct START //----------------------------------------------------------------------------------- + if (Nsrc > QUDA_MAX_MULTI_SRC) + errorQuda("Nsrc = %d which is greater than QUDA_MAX_MULTI_SRC = %d\n", Nsrc, QUDA_MAX_MULTI_SRC); std::vector in(Nsrc); std::vector out(Nsrc); std::vector out_multishift(Nsrc * multishift); @@ -361,7 +372,7 @@ std::vector> solve(test_t param) // QUDA invert test //---------------------------------------------------------------------------- - if (!use_split_grid) { + if (!use_multi_src || multishift > 1) { for (int n = 0; n < Nsrc; n++) { // If deflating, preserve the deflation space between solves @@ -373,39 +384,62 @@ std::vector> solve(test_t param) invertQuda(out[n].data(), in[n].data(), &inv_param); } + // move residuals to n^th location for verification after solves have finished + inv_param.true_res[n] = inv_param.true_res[0]; + inv_param.true_res_hq[n] = inv_param.true_res_hq[0]; + time[n] = inv_param.secs; gflops[n] = inv_param.gflops / inv_param.secs; iter[n] = inv_param.iter; - printfQuda("Done: %i iter / %g secs = %g Gflops\n\n", inv_param.iter, inv_param.secs, + printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); + if (inv_param.energy > 0) { + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", inv_param.energy, + inv_param.power, inv_param.temp, inv_param.clock); + } } } else { - inv_param.num_src = Nsrc; - inv_param.num_src_per_sub_partition = Nsrc / num_sub_partition; + + inv_param.num_src = Nsrc_tile; + inv_param.num_src_per_sub_partition = Nsrc_tile / num_sub_partition; // Host arrays for solutions, sources, and check - std::vector _hp_x(Nsrc); - std::vector _hp_b(Nsrc); - for (int n = 0; n < Nsrc; n++) { - _hp_x[n] = out[n].data(); - _hp_b[n] = in[n].data(); + std::vector _hp_x(Nsrc_tile); + std::vector _hp_b(Nsrc_tile); + + for (int j = 0; j < Nsrc; j += Nsrc_tile) { + for (int i = 0; i < Nsrc_tile; i++) { + _hp_x[i] = out[j + i].data(); + _hp_b[i] = in[j + i].data(); + } + + if (inv_deflate) eig_param.preserve_deflation = j < Nsrc - Nsrc_tile ?
QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; + invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); + + // move residuals to (i+j)^th location for verification after solves have finished + for (int i = 0; i < Nsrc_tile; i++) { + inv_param.true_res[j + i] = inv_param.true_res[i]; + inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i]; + } + + quda::comm_allreduce_int(inv_param.iter); + inv_param.iter /= comm_size() / num_sub_partition; + quda::comm_allreduce_sum(inv_param.gflops); + inv_param.gflops /= comm_size() / num_sub_partition; + quda::comm_allreduce_max(inv_param.secs); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, + inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); + if (inv_param.energy > 0) { + printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", + inv_param.energy, inv_param.energy / Nsrc_tile, inv_param.power, inv_param.temp, inv_param.clock); + } } - // Run split grid - invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); - - quda::comm_allreduce_int(inv_param.iter); - inv_param.iter /= comm_size() / num_sub_partition; - quda::comm_allreduce_sum(inv_param.gflops); - inv_param.gflops /= comm_size() / num_sub_partition; - quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n\n", num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs); } // Free the multigrid solver if (inv_multigrid) destroyMultigridQuda(mg_preconditioner); // Compute timings - if (Nsrc > 1 && !use_split_grid) performanceStats(time, gflops, iter); + if (!use_multi_src) performanceStats(time, gflops, iter); std::vector> res(Nsrc); // Perform host side verification of inversion if requested @@ -418,7 +452,7 @@ std::vector> solve(test_t param) = {out_multishift.begin() + n * multishift, out_multishift.begin() + (n + 1) * multishift}; res[n] = verifyStaggeredInversion(in[n], out_subset, cpuFatQDP, cpuLongQDP, inv_param); } else { - res[n] = verifyStaggeredInversion(in[n], out[n], cpuFatQDP, cpuLongQDP, inv_param); + res[n] = verifyStaggeredInversion(in[n], out[n], cpuFatQDP, cpuLongQDP, inv_param, n); } } } diff --git a/tests/test.h b/tests/test.h index 7b726433a3..27552f0c10 100644 --- a/tests/test.h +++ b/tests/test.h @@ -2,6 +2,7 @@ #include #include #include +#include "asan.h" struct quda_test { diff --git a/tests/unitarize_link_test.cpp b/tests/unitarize_link_test.cpp index 4cd8553fdd..f6cf4c642d 100644 --- a/tests/unitarize_link_test.cpp +++ b/tests/unitarize_link_test.cpp @@ -9,6 +9,7 @@ #include "host_utils.h" #include #include "misc.h" +#include "test.h" #include "util_quda.h" #include "llfat_quda.h" #include diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index a16088a982..43d2745035 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -66,6 +66,7 @@ int pipeline = 0; int solution_accumulator_pipeline = 0; int test_type = 0; quda::mgarray nvec = {}; +quda::mgarray nvec_batch = {}; quda::mgarray mg_vec_infile; quda::mgarray mg_vec_outfile; quda::mgarray mg_vec_partfile = {}; @@ -212,6 +213,7 @@ QudaMemoryType mem_type_ritz = QUDA_MEMORY_DEVICE; // Parameters for the stand alone eigensolver int eig_ortho_block_size = 0; +int eig_evals_batch_size = 16; int eig_block_size = 4; int eig_n_ev = 16; int eig_n_kr = 32; @@ -249,6 +251,7 @@ bool eig_partfile = false; // all 
others are for PR vectors. quda::mgarray mg_eig = {}; quda::mgarray mg_eig_ortho_block_size = {}; +quda::mgarray mg_eig_evals_batch_size = {}; quda::mgarray mg_eig_block_size = {}; quda::mgarray mg_eig_n_ev_deflate = {}; quda::mgarray mg_eig_n_ev = {}; @@ -788,6 +791,8 @@ void add_eigen_option_group(std::shared_ptr quda_app) opgroup->add_option("--eig-ortho-block-size", eig_ortho_block_size, "The block size to use when orthonormalising vectors in hybrid modified Gram-Schmidt" "0 for always Classical, 1 for Modified, n > 1 for Hybrid)"); + opgroup->add_option("--eig-evals-batch-size", eig_evals_batch_size, + "The batch size used when computing eigenvalues in the eigensolver"); opgroup->add_option("--eig-block-size", eig_block_size, "The block size to use in the block variant eigensolver"); opgroup->add_option( "--eig-n-ev-deflate", eig_n_ev_deflate, @@ -932,6 +937,8 @@ void add_multigrid_option_group(std::shared_ptr quda_app) "Use Eigen to eigensolve the upper Hessenberg in IRAM, else use QUDA's QR code. (default true)"); quda_app->add_mgoption(opgroup, "--mg-eig-ortho-block-size", mg_eig_ortho_block_size, CLI::Validator(), "The block size to use when orthonormalising vectors in hybrid modified Gram-Schmidt"); + quda_app->add_mgoption(opgroup, "--mg-eig-evals-batch-size", mg_eig_evals_batch_size, CLI::Validator(), + "The batch size used when computing eigenvalues in the eigensolver"); quda_app->add_mgoption(opgroup, "--mg-eig-block-size", mg_eig_block_size, CLI::Validator(), "The block size to use in the block variant eigensolver"); quda_app->add_mgoption(opgroup, "--mg-eig-n-ev", mg_eig_n_ev, CLI::Validator(), @@ -1008,6 +1015,9 @@ void add_multigrid_option_group(std::shared_ptr quda_app) "The number of pre-smoother applications to do at a given multigrid level (default 2)"); quda_app->add_mgoption(opgroup, "--mg-nvec", nvec, CLI::PositiveNumber, "Number of null-space vectors to define the multigrid transfer operator on a given level"); + quda_app->add_mgoption( + opgroup, "--mg-nvec-batch", nvec_batch, + CLI::PositiveNumber, "Batch size to use when computing the null-space vectors to define the multigrid transfer operator on a given level"); opgroup->add_option("--mg-oblique-proj-check", oblique_proj_check, "Measure how well the null vector subspace adjusts the low eigenmode subspace (default false)"); opgroup->add_option("--mg-omega", omega, diff --git a/tests/utils/command_line_params.h b/tests/utils/command_line_params.h index 2bc8b588ef..338dcc73bc 100644 --- a/tests/utils/command_line_params.h +++ b/tests/utils/command_line_params.h @@ -325,6 +325,7 @@ extern int pipeline; extern int solution_accumulator_pipeline; extern int test_type; extern quda::mgarray nvec; +extern quda::mgarray nvec_batch; extern quda::mgarray mg_vec_infile; extern quda::mgarray mg_vec_outfile; extern quda::mgarray mg_vec_partfile; @@ -461,6 +462,7 @@ extern QudaMemoryType mem_type_ritz; // Parameters for the stand alone eigensolver extern int eig_ortho_block_size; +extern int eig_evals_batch_size; extern int eig_block_size; extern int eig_n_ev; extern int eig_n_kr; @@ -498,6 +500,7 @@ extern bool eig_partfile; // all others are for PR vectors.
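The new batch-size flags above go through the same CLI11 machinery (CLI::Validator, CLI::PositiveNumber) that the test harness already uses. A minimal standalone sketch of registering one such option, assuming CLI11 is available:

```cpp
// Minimal CLI11 sketch mirroring how an --eig-evals-batch-size style
// option is registered; the option name and default are illustrative.
#include <CLI/CLI.hpp>

int main(int argc, char **argv)
{
  CLI::App app {"option sketch"};
  int eig_evals_batch_size = 16; // default matches the harness initializer
  app.add_option("--eig-evals-batch-size", eig_evals_batch_size,
                 "The batch size used when computing eigenvalues in the eigensolver")
    ->check(CLI::PositiveNumber); // reject zero or negative batch sizes
  CLI11_PARSE(app, argc, argv);
  return 0;
}
```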
extern quda::mgarray mg_eig; extern quda::mgarray mg_eig_ortho_block_size; +extern quda::mgarray mg_eig_evals_batch_size; extern quda::mgarray mg_eig_block_size; extern quda::mgarray mg_eig_n_ev_deflate; extern quda::mgarray mg_eig_n_ev; diff --git a/tests/utils/host_utils.cpp b/tests/utils/host_utils.cpp index 9df5050cc1..064b64dd71 100644 --- a/tests/utils/host_utils.cpp +++ b/tests/utils/host_utils.cpp @@ -395,7 +395,7 @@ void initComms(int, char **, int *const commDims) #if defined(QMP_COMMS) QMP_thread_level_t tl; - QMP_init_msg_passing(&argc, &argv, QMP_THREAD_SINGLE, &tl); + QMP_init_msg_passing(&argc, &argv, QMP_THREAD_FUNNELED, &tl); // make sure the QMP logical ordering matches QUDA's if (rank_order == 0) { @@ -406,7 +406,15 @@ void initComms(int, char **, int *const commDims) QMP_declare_logical_topology_map(commDims, 4, map, 4); } #elif defined(MPI_COMMS) - MPI_Init(&argc, &argv); + int provided = 0; + int required = MPI_THREAD_FUNNELED; + int flag = MPI_Init_thread(&argc, &argv, required, &provided); + + if (provided != required) { + printf("%s: required thread-safety level %d can't be provided %d\n", __func__, required, provided); + fflush(stdout); + exit(flag); + } #endif QudaCommsMap func = rank_order == 0 ? lex_rank_from_coords_t : lex_rank_from_coords_x; @@ -763,16 +771,6 @@ int fullLatticeIndex(int i, int oddBit) return X; } -extern "C" { -/** - @brief Set the default ASAN options. This ensures that QUDA just - works when SANITIZE is enabled without requiring ASAN_OPTIONS to be - set. We default disable leak checking, otherwise this will cause - ctest to fail with MPI library leaks. - */ -const char *__asan_default_options() { return "detect_leaks=0,protect_shadow_gap=0"; } -} - /** * For MPI, the default node mapping is lexicographical with t varying fastest. */ diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index 3ba1900cf0..7dd43a651a 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -142,6 +142,9 @@ void setInvertParam(QudaInvertParam &inv_param) // Use 3D or 4D laplace inv_param.laplace3D = laplace3D; + if (Nsrc < Nsrc_tile || Nsrc % Nsrc_tile != 0) + errorQuda("Invalid combination Nsrc = %d Nsrc_tile = %d", Nsrc, Nsrc_tile); + // Some fermion specific parameters if (dslash_type == QUDA_TWISTED_MASS_DSLASH || dslash_type == QUDA_TWISTED_CLOVER_DSLASH) { inv_param.mu = mu; @@ -332,6 +335,7 @@ void setEigParam(QudaEigParam &eig_param) } eig_param.ortho_block_size = eig_ortho_block_size; + eig_param.compute_evals_batch_size = eig_evals_batch_size; eig_param.block_size = (eig_param.eig_type == QUDA_EIG_TR_LANCZOS || eig_param.eig_type == QUDA_EIG_IR_ARNOLDI) ? 1 : eig_block_size; eig_param.n_ev = eig_n_ev; @@ -471,6 +475,7 @@ void setMultigridParam(QudaMultigridParam &mg_param) mg_param.spin_block_size[i] = 1; mg_param.n_vec[i] = nvec[i] == 0 ? 24 : nvec[i]; // default to 24 vectors if not set + mg_param.n_vec_batch[i] = nvec_batch[i] == 0 ? 1 : nvec_batch[i]; // default to batch size 1 if not set mg_param.n_block_ortho[i] = n_block_ortho[i]; // number of times to Gram-Schmidt mg_param.block_ortho_two_pass[i] = block_ortho_two_pass[i] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; // whether to use a two-pass block ortho @@ -789,6 +794,8 @@ void setMultigridEigParam(QudaEigParam &mg_eig_param, int level) mg_eig_param.n_ev = mg_eig_n_ev[level]; mg_eig_param.n_kr = mg_eig_n_kr[level]; mg_eig_param.n_conv = nvec[level]; + mg_eig_param.compute_evals_batch_size + = mg_eig_evals_batch_size[level] ? 
mg_eig_evals_batch_size[level] : eig_evals_batch_size; // Inverters will deflate only this number of vectors. if (mg_eig_n_ev_deflate[level] > 0 && mg_eig_n_ev_deflate[level] < mg_eig_param.n_conv) @@ -931,6 +938,9 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param) inv_param.kappa = kappa = 1.0 / (8.0 + mass); // for Laplace operator inv_param.laplace3D = laplace3D; // for Laplace operator + if (Nsrc < Nsrc_tile || Nsrc % Nsrc_tile != 0) + errorQuda("Invalid combination Nsrc = %d Nsrc_tile = %d", Nsrc, Nsrc_tile); + // outer solver parameters inv_param.inv_type = inv_type; inv_param.tol = tol; @@ -941,8 +951,10 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param) inv_param.use_sloppy_partial_accumulator = false; inv_param.solution_accumulator_pipeline = solution_accumulator_pipeline; inv_param.pipeline = pipeline; + inv_param.max_res_increase = max_res_increase; + inv_param.max_res_increase_total = max_res_increase_total; - inv_param.Ls = 1; // Nsrc + inv_param.Ls = 1; if (tol_hq == 0 && tol == 0) { errorQuda("qudaInvert: requesting zero residual\n"); @@ -1078,6 +1090,7 @@ void setStaggeredMultigridParam(QudaMultigridParam &mg_param) mg_param.spin_block_size[i] = 1; mg_param.n_vec[i] = nvec[i] == 0 ? 64 : nvec[i]; // default to 64 vectors if not set + mg_param.n_vec_batch[i] = nvec_batch[i] == 0 ? 1 : nvec_batch[i]; // default to batch size 1 if not set mg_param.n_block_ortho[i] = n_block_ortho[i]; // number of times to Gram-Schmidt mg_param.block_ortho_two_pass[i] = block_ortho_two_pass[i] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; // whether to use a two-pass block ortho
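Both setInvertParam and setStaggeredInvertParam above now reject source counts that do not tile evenly, so the multi-source loop `for (j = 0; j < Nsrc; j += Nsrc_tile)` covers every source exactly once. A small worked example of the rule (the helper name is illustrative):

```cpp
#include <cstdio>

// mirrors the guard: Nsrc must be at least one full tile and an exact multiple of the tile size
bool valid_tiling(int Nsrc, int Nsrc_tile) { return Nsrc >= Nsrc_tile && Nsrc % Nsrc_tile == 0; }

int main()
{
  printf("%d\n", valid_tiling(8, 4)); // 1: two tiles of four sources
  printf("%d\n", valid_tiling(6, 4)); // 0: 6 % 4 != 0, last tile would be partial
  printf("%d\n", valid_tiling(2, 4)); // 0: fewer sources than a single tile
}
```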