diff --git a/CMakeLists.txt b/CMakeLists.txt index 29643264fc..7368db7a03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -448,8 +448,8 @@ set(SeQuant_optimize_src SeQuant/core/optimize/common_subexpression_elimination.hpp SeQuant/core/optimize/cost_model.hpp SeQuant/core/optimize/extract_subtrees.hpp - SeQuant/core/optimize/fusion.cpp - SeQuant/core/optimize/fusion.hpp + SeQuant/core/optimize/multiterm.cpp + SeQuant/core/optimize/multiterm.hpp SeQuant/core/optimize/options.hpp SeQuant/core/optimize/optimize.cpp SeQuant/core/optimize/optimize.hpp diff --git a/SeQuant/core/optimize/fusion.cpp b/SeQuant/core/optimize/fusion.cpp deleted file mode 100644 index 680ce44c56..0000000000 --- a/SeQuant/core/optimize/fusion.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace sequant::opt { - -using ranges::views::drop; -using ranges::views::reverse; -using ranges::views::zip; - -// convert a Product of single tensor and scalar == 1 into a tensor exprptr -auto lift_tensor = [](Product const& p) -> ExprPtr { - return p.scalar() == 1 && p.size() == 1 - ? p.factor(0) - : ex(Product{p.scalar(), p.factors().begin(), - p.factors().end(), Product::Flatten::No}); -}; - -Fusion::Fusion(Product const& lhs, Product const& rhs) - : left_{fuse_left(lhs, rhs)}, right_{fuse_right(lhs, rhs)} {} - -ExprPtr Fusion::left() const { return left_; } - -ExprPtr Fusion::right() const { return right_; } - -ExprPtr Fusion::fuse_left(Product const& lhs, Product const& rhs) { - auto fac = container::svector{}; - - for (auto&& [e1, e2] : zip(lhs.factors(), rhs.factors())) { - if (e1 == e2) - fac.push_back(e1); - else - break; - } - - if (fac.empty()) return nullptr; - - auto lsmand = lhs.factors() | drop(fac.size()); - auto rsmand = rhs.factors() | drop(fac.size()); - - auto fac_prod = Product{fac.begin(), fac.end()}; - auto lsmand_prod = Product{lsmand.begin(), lsmand.end()}; - auto rsmand_prod = Product{rsmand.begin(), rsmand.end()}; - - SEQUANT_ASSERT(lhs.scalar().imag().is_zero() && - rhs.scalar().imag().is_zero() && - "Complex valued gcd not supported"); - auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real()); - - fac_prod.scale(scalars_fused.at(0)); - lsmand_prod.scale(scalars_fused.at(1)); - rsmand_prod.scale(scalars_fused.at(2)); - - // f (a + b) - - auto f = lift_tensor(fac_prod); - auto a = lift_tensor(lsmand_prod); - auto b = lift_tensor(rsmand_prod); - - return ex(ExprPtrList{f, ex(ExprPtrList{a, b})}); -} - -ExprPtr Fusion::fuse_right(Product const& lhs, Product const& rhs) { - auto fac = container::svector{}; - - for (auto&& [e1, e2] : - zip(lhs.factors() | reverse, rhs.factors() | reverse)) { - if (e1 == e2) - fac.push_back(e1); - else - break; - } - - if (fac.empty()) return nullptr; - - ranges::reverse(fac); - auto lsmand = lhs.factors() | reverse | drop(fac.size()) | reverse; - auto rsmand = rhs.factors() | reverse | drop(fac.size()) | reverse; - - auto fac_prod = Product{fac.begin(), fac.end()}; - auto lsmand_prod = Product{lsmand.begin(), lsmand.end()}; - auto rsmand_prod = Product{rsmand.begin(), rsmand.end()}; - - SEQUANT_ASSERT(lhs.scalar().imag().is_zero() && - rhs.scalar().imag().is_zero() && - "Complex valued gcd not supported"); - auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real()); - - fac_prod.scale(scalars_fused.at(0)); - lsmand_prod.scale(scalars_fused.at(1)); - rsmand_prod.scale(scalars_fused.at(2)); - - // (a + b) f - - auto a = lift_tensor(lsmand_prod); - auto b = lift_tensor(rsmand_prod); - auto f = lift_tensor(fac_prod); - - return ex(ExprPtrList{ex(ExprPtrList{a, b}), f}); -} - -rational Fusion::gcd_rational(rational const& left, rational const& right) { - auto&& r1 = left.real(); - auto&& r2 = right.real(); - auto&& n1 = numerator(r1); - auto&& d1 = denominator(r1); - auto&& n2 = numerator(r2); - auto&& d2 = denominator(r2); - - auto num = gcd(n1 * d2, n2 * d1); - return {num, d1 * d2}; -} - -std::array Fusion::fuse_scalar(rational const& left, - rational const& right) { - auto fused = gcd_rational(left, right); - rational left_fused = left / fused; - rational right_fused = right / fused; - if (left < 0 && right < 0) { - fused *= -1; - left_fused *= -1; - right_fused *= -1; - } - return {fused, left_fused, right_fused}; -} - -} // namespace sequant::opt diff --git a/SeQuant/core/optimize/fusion.hpp b/SeQuant/core/optimize/fusion.hpp deleted file mode 100644 index 210a9b1201..0000000000 --- a/SeQuant/core/optimize/fusion.hpp +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef SEQUANT_OPT_FUSION_HPP -#define SEQUANT_OPT_FUSION_HPP - -#include - -namespace sequant::opt { - -/// -/// Use this class to test the fusibility of two products. -/// -/// The pattern of factors should match in both products for them -/// to have a common factor. -/// Fusion is either possible from the left hand side or the right hand side. -/// -/// eg. abcd + abef => ab(cd + ef) from left. nullptr from right. -/// abef + cdef => nullptr from left. (ab + cd)ef from right. -/// -/// Only common scalars are factored out as of now. -/// eg. (1/2)abcd + (1/2)abef => (1/2)ab(cd + ef) -/// (1/2)abcd + (1/4)abef => ab((1/2)cd + (1/4)ef) -/// (1/2)abcd - (1/2)abef => ab((1/2)cd - (1/2)ef) - -class Fusion { - public: - Fusion(Product const& lhs, Product const& rhs); - - /// the result of fusion from left hand side. - /// returns nullptr if no fusion possible. - ExprPtr left() const; - - /// the result of fusion from right hand side. - /// returns nullptr if no fusion possible. - ExprPtr right() const; - - static ExprPtr fuse_left(Product const& lhs, Product const& rhs); - - static ExprPtr fuse_right(Product const& lhs, Product const& rhs); - - /// - /// Get the greatest common divisor of two rational numbers. - /// - static rational gcd_rational(rational const& left, rational const& right); - - /// - /// Fuse scalars @param left and @param right and the return the result - /// as an array of three elements: first is the greatest common factor, - /// second the fused sub-factor of @param left and the third is that - /// of @param right. - /// - static std::array fuse_scalar(rational const& left, - rational const& right); - - private: - ExprPtr left_; - - ExprPtr right_; -}; - -} // namespace sequant::opt - -#endif // SEQUANT_OPT_FUSION_HPP diff --git a/SeQuant/core/optimize/multiterm.cpp b/SeQuant/core/optimize/multiterm.cpp new file mode 100644 index 0000000000..2ea75affbc --- /dev/null +++ b/SeQuant/core/optimize/multiterm.cpp @@ -0,0 +1,595 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sequant::opt { + +namespace { + +using Node = FullBinaryNode; +using index_vector = EvalExpr::index_vector; +using scalar_type = Constant::scalar_type; ///< Complex +using Signature = std::vector; ///< per-bucket signature key + +// =========================================================================== +// Cost model +// =========================================================================== + +/// Cost model for biclique scoring: owns the active OptimizeOptions and answers +/// the three size/cost questions the search asks. Constructed once per +/// factorize_multiterm call so the metric and extent map need not be threaded +/// through every search signature. +/// +/// \note Distinct from \ref detail::CostModel (the \c cost_model.hpp concept +/// for the single-term DP's \c AdditiveModel / \c PeakModel / +/// \c PeakBatchedModel): that optimizes contraction order within one term, +/// this scores cross-summand factoring. \c contraction_cost already reuses +/// \ref detail::flops_counter / \ref detail::memsize_counter, and +/// \c tensor_size re-derives \ref detail::footprint_counter's element count +/// under a different convention -- unifying the two is deferred until Stage 3 +/// (dummy-invariant matching) settles this class's interface. +class BicliqueCostModel { + OptimizeOptions const& opts_; + + public: + explicit BicliqueCostModel(OptimizeOptions const& o) : opts_(o) {} + + /// Product of index extents -- the number of elements of a tensor with the + /// given result modes (1 for a scalar / empty index list). + double tensor_size(index_vector const& idxs) const { + double s = 1.; + for (Index const& ix : idxs) + s *= static_cast(opts_.idx_to_extent(ix)); + return s; + } + + /// Cost of the single binary contraction L*R -> result, under the active + /// metric. This is the same base counter single-term optimization uses + /// (\ref detail::flops_counter / \ref detail::memsize_counter), so the two + /// passes agree on the base cost of a contraction; single-term's volatile- + /// and footprint-weighting is not applied here (see the class \note above). + double contraction_cost(index_vector const& l, index_vector const& r, + index_vector const& res) const { + if (opts_.objective_function == ObjectiveFunction::DenseSize) + return detail::memsize_counter(opts_.idx_to_extent)(l, r, res); + return detail::flops_counter(opts_.idx_to_extent)(l, r, res); + } + + /// Saving of an m x n biclique that rewrites m*n contractions L_i*R_j as one + /// (sum_i L_i)*(sum_j R_j) plus the two factor-sums. + /// + /// \note The coefficient is (m*n - 1): the fold replaces m*n contractions + /// with a single (sum L)*(sum R), so m*n - 1 contractions are avoided; from + /// that we subtract the two factor-sum build costs. Result-add savings in the + /// original sum are ignored, i.e. this is a conservative lower bound. + double saving(std::size_t m, std::size_t n, double c_final, double l_size, + double r_size) const { + double const avoided = (static_cast(m * n) - 1.) * c_final; + double build = 0.; + if (m > 1) build += static_cast(m - 1) * l_size; + if (n > 1) build += static_cast(n - 1) * r_size; + return avoided - build; + } +}; + +// =========================================================================== +// Forest -> bipartite graph +// =========================================================================== + +/// Label-exact, phase-relaxed factor equality. Two factor subtrees match iff +/// the structural/connectivity comparator agrees and their canonical result +/// indices are label-identical; a difference in canonical phase does not block +/// the match. Two factors that differ only by a canonicalization sign share one +/// vertex; the relative sign is folded into the edge coefficient (and thence +/// into a partner) at emission. canon_indices equality is retained: every +/// partner still exposes the same labeled free indices, so the factor-sums +/// remain summable with no relabeling (dummy-invariant matching is not handled +/// here). +struct FactorHash { + std::size_t operator()(Node const* n) const { + return TreeNodeHasher{}(*n); + } +}; +struct FactorEq { + bool operator()(Node const* a, Node const* b) const { + if (!TreeNodeEqualityComparator{}(*a, *b)) return false; + index_vector const& ia = (*a)->canon_indices(); + index_vector const& ib = (*b)->canon_indices(); + if (ia.size() != ib.size()) return false; + for (std::size_t k = 0; k < ia.size(); ++k) + if (!(ia[k] == ib[k])) return false; + return true; // phase relaxed + } +}; + +/// Interns factor subtrees into dense vertex ids 0..V-1 under the label-exact, +/// phase-relaxed predicate: a structurally identical factor gets +/// the same id whether it sits on the left or the right of a contraction and +/// regardless of its canonicalization sign. The first factor interned for a +/// vertex is its representative; \ref phase records that representative's +/// canonicalization phase so the per-edge relative sign can be recovered (see +/// \ref build_buckets). Ids are dense and assignment-ordered, so the reps and +/// phases are kept in flat vectors indexed by id. +class Interner { + std::unordered_map vmap_; + std::vector reps_; ///< dense: id -> representative node + std::vector phases_; ///< dense: id -> representative phase + + public: + /// Vertex id of \p n, assigning a fresh dense id (and recording \p n as the + /// representative, with its canonicalization phase) on first sight. + int intern(Node const& n) { + auto [it, inserted] = vmap_.try_emplace(&n, static_cast(vmap_.size())); + if (inserted) { + reps_.push_back(&n); + phases_.push_back((*n).canon_phase()); + } + return it->second; + } + Node const* rep(int id) const { return reps_[id]; } + std::int8_t phase(int id) const { return phases_[id]; } +}; + +/// True if \p node is a top-level binary contraction of two tensor subtrees. +/// Only a tensor*tensor contraction carries a connectivity graph. +bool is_splittable(Node const& node) { + return !node.leaf() && node->op_type() == EvalOp::Product && + node->has_connectivity_graph(); +} + +/// The contraction at the heart of a summand together with its scalar +/// prefactor. A plain tensor*tensor summand binarizes straight to a splittable +/// contraction (coeff == 1). A scalar-prefactored summand binarizes to +/// Product(contraction, Constant) with a null connectivity graph at the root; +/// we peel the Constant off and recurse to the contraction one level down so +/// that \c A*B - A*C and \c 2*A*B + 3*A*C participate in folding. Returns +/// \c std::nullopt for leaves, scalar*single-tensor, and +/// non-Constant (e.g. Variable) prefactors -- those pass through untouched. +struct PrefactoredContraction { + Node const* node = nullptr; ///< the splittable contraction + scalar_type coeff{1}; ///< summand = coeff * to_expr(node) +}; +std::optional extract_core(Node const& root) { + if (is_splittable(root)) return PrefactoredContraction{&root, scalar_type{1}}; + if (root.leaf() || root->op_type() != EvalOp::Product) return std::nullopt; + // Product(contraction, scalar) -- scalar rides as a Constant child, the + // contraction as the other (order is binarize's choice, so try both). + auto try_pair = + [](Node const& core, + Node const& scal) -> std::optional { + if (scal->is_scalar() && scal->expr() && scal->expr()->is() && + is_splittable(core)) + return PrefactoredContraction{&core, + scal->expr()->as().value()}; + return std::nullopt; + }; + if (auto c = try_pair(root.left(), root.right())) return c; + if (auto c = try_pair(root.right(), root.left())) return c; + return std::nullopt; +} + +/// Signature bucket key: the sorted multiset of index-space identifiers over +/// the final contraction's external and contracted indices (spaces only, +/// label-agnostic). Terms only factor together within a bucket; cross-bucket +/// bicliques are impossible by construction. +Signature signature(Node const& root) { + index_vector const& ext = root->canon_indices(); + index_vector const& l = root.left()->canon_indices(); + + auto push_space = [](Signature& v, Index const& ix) { + v.push_back(static_cast(ix.space().attr())); + }; + + Signature key; + for (Index const& ix : ext) push_space(key, ix); + auto is_external = [&ext](Index const& ix) { + for (Index const& e : ext) + if (e == ix) return true; + return false; + }; + // Each contracted index appears on both L and R; scan L only to count once. + for (Index const& ix : l) + if (!is_external(ix)) push_space(key, ix); + std::sort(key.begin(), key.end()); + return key; +} + +/// One splittable summand: an edge between its left and right factor vertices, +/// carrying the per-summand coefficient that relates the summand's value to the +/// product of the two vertex representatives' expressions: +/// summand = coeff * to_expr(left_rep) * to_expr(right_rep). +/// coeff absorbs the scalar prefactor (\ref PrefactoredContraction) and the +/// relative canonicalization sign of each factor versus its vertex +/// representative. +struct Edge { + std::size_t pos; ///< position in sum.summands() + int lvid, rvid; ///< vertex ids of the left / right child factors + scalar_type coeff; ///< summand = coeff * L_rep * R_rep +}; + +struct Bucket { + double c_final = 0.; ///< memoized final-contraction cost (signature-only) + std::vector edges; +}; + +/// The live bipartite view of a bucket over not-yet-consumed edges. Parallel +/// edges (same vertex pair, e.g. duplicate summands) have their coefficients +/// summed and their positions concatenated. Built once per greedy round via +/// \ref build; the search reads it only through the accessors below, so the +/// underlying maps stay an implementation detail. +class Live { + std::map> left_adj_; ///< lvid -> sorted rvids + std::map left_rep_, right_rep_; + std::map, std::vector> edge_pos_; + std::map, scalar_type> edge_coeff_; + + public: + static Live build(Bucket const& bucket, std::vector const& consumed, + Interner const& interner); + + /// lvid -> sorted rvids, over the live (not-yet-consumed) edges. + std::map> const& left_adj() const { return left_adj_; } + Node const* left_rep(int l) const { return left_rep_.at(l); } + Node const* right_rep(int r) const { return right_rep_.at(r); } + /// Summand positions covered by the (l, r) edge. Defined only for present + /// edges -- throws on an absent vertex pair. + std::vector const& positions(int l, int r) const { + return edge_pos_.at({l, r}); + } + /// Folded coefficient of the (l, r) edge. Defined only for present edges. + scalar_type coeff(int l, int r) const { return edge_coeff_.at({l, r}); } +}; + +Live Live::build(Bucket const& bucket, std::vector const& consumed, + Interner const& interner) { + Live live; + std::map> adj; + for (Edge const& e : bucket.edges) { + if (consumed[e.pos]) continue; + adj[e.lvid].insert(e.rvid); + live.left_rep_.emplace(e.lvid, interner.rep(e.lvid)); + live.right_rep_.emplace(e.rvid, interner.rep(e.rvid)); + auto key = std::pair{e.lvid, e.rvid}; + live.edge_pos_[key].push_back(e.pos); + auto [it, inserted] = live.edge_coeff_.try_emplace(key, e.coeff); + if (!inserted) it->second = it->second + e.coeff; + } + for (auto const& [l, rs] : adj) + live.left_adj_.emplace(l, std::vector(rs.begin(), rs.end())); + return live; +} + +/// Build the per-signature buckets of factorable edges from the binarized +/// summands, interning factor vertices into \p interner along the way. +/// +/// Each splittable summand contributes one edge between its left and right +/// factor vertices. The edge coefficient relates the summand's value to the +/// product of the two vertex representatives' expressions: +/// summand = coeff * to_expr(L_rep) * to_expr(R_rep), +/// coeff = prefactor * sigma_L * sigma_R, +/// sigma = canon_phase(factor) * canon_phase(representative) (each +/-1), +/// so the relative canonicalization sign between a factor and its vertex +/// representative rides on the edge. Non-splittable +/// summands (leaf, scalar*leaf, opaque prefactor) contribute no edge and pass +/// through untouched. \c c_final is memoized per bucket on its first edge +/// (signature-only, hence equal for every member). +std::map build_buckets(container::vector const& nodes, + Interner& interner, + BicliqueCostModel const& cost) { + std::map buckets; + for (std::size_t i = 0; i < nodes.size(); ++i) { + auto core = extract_core(nodes[i]); + if (!core) continue; // leaf / scalar*leaf / opaque prefactor: untouched + Node const& c = *core->node; + Node const& lnode = c.left(); + Node const& rnode = c.right(); + int const lv = interner.intern(lnode), rv = interner.intern(rnode); + scalar_type const sigma_l(static_cast(lnode->canon_phase()) * + static_cast(interner.phase(lv))); + scalar_type const sigma_r(static_cast(rnode->canon_phase()) * + static_cast(interner.phase(rv))); + scalar_type const coeff = core->coeff * sigma_l * sigma_r; + + auto& bucket = buckets[signature(c)]; + if (bucket.edges.empty()) + bucket.c_final = cost.contraction_cost( + lnode->canon_indices(), rnode->canon_indices(), c->canon_indices()); + bucket.edges.push_back(Edge{i, lv, rv, coeff}); + } + return buckets; +} + +// =========================================================================== +// Biclique search +// =========================================================================== + +std::vector intersect_sorted(std::vector const& a, + std::vector const& b) { + std::vector out; + std::set_intersection(a.begin(), a.end(), b.begin(), b.end(), + std::back_inserter(out)); + return out; +} + +/// A factorable biclique S_L x S_R within one bucket's live bipartite graph. +/// The coefficient of each covered summand factors as left_coeffs[i] * +/// right_coeffs[j] (rank-1), which is what lets it be emitted as the single +/// product (sum_i left_coeffs[i] L_i) * (sum_j right_coeffs[j] R_j). +struct Biclique { + std::vector left, right; ///< vertex ids + std::vector positions; ///< covered summand positions + std::vector left_reps, right_reps; + std::vector left_coeffs, right_coeffs; + double saving = 0.; + std::size_t min_pos = 0; ///< for deterministic tie-breaking +}; + +/// Strict "candidate beats incumbent" order: higher saving wins; ties broken by +/// lowest covered position (deterministic). A null incumbent is always beaten. +bool supersedes(Biclique const& cand, std::optional const& best) { + return !best || cand.saving > best->saving || + (cand.saving == best->saving && cand.min_pos < best->min_pos); +} + +/// Try to factor the m x n coefficient matrix c[i][j] = edge_coeff(left[i], +/// right[j]) as an outer product alpha[i] * beta[j] (multiplicative rank 1). +/// Returns nullopt if it does not factor (then the biclique must be reduced to +/// a one-sided fold, which is always factorable). For a complete biclique every +/// (left[i], right[j]) edge is present, so every lookup succeeds. +std::optional, std::vector>> +factor_coeffs(std::vector const& left, std::vector const& right, + Live const& live) { + auto at = [&](int l, int r) { return live.coeff(l, r); }; + scalar_type const c00 = at(left.front(), right.front()); + if (c00.is_zero()) return std::nullopt; + // rank-1 iff c[i][j] * c[0][0] == c[i][0] * c[0][j] for all i, j. + for (std::size_t i = 0; i < left.size(); ++i) + for (std::size_t j = 0; j < right.size(); ++j) + if (!(at(left[i], right[j]) * c00 == + at(left[i], right.front()) * at(left.front(), right[j]))) + return std::nullopt; + std::vector alpha, beta; + alpha.reserve(left.size()); + beta.reserve(right.size()); + for (int l : left) alpha.push_back(at(l, right.front())); + for (int r : right) beta.push_back(at(left.front(), r) / c00); + return std::pair{std::move(alpha), std::move(beta)}; +} + +/// Fill the geometry (positions, reps, min_pos) of a biclique over the given +/// vertex sets, then score it. Coefficients (\p alpha, \p beta) are supplied by +/// the caller (already validated factorable). +/// +/// \pre \p left x \p right is a complete biclique: every (l, r) pair is a live +/// edge. This walks the full product through \ref Live::positions, which throws +/// on an absent pair. Both callers honor it -- \ref best_biclique passes +/// maximal complete vertex sets, and \ref best_fold's one-sided slices are a +/// single row or column of one. +Biclique make_biclique(std::vector left, std::vector right, + std::vector alpha, + std::vector beta, double c_final, + Live const& live, BicliqueCostModel const& cost) { + Biclique bc; + std::size_t const m = left.size(), n = right.size(); + double const l_size = + cost.tensor_size((*live.left_rep(left.front()))->canon_indices()); + double const r_size = + cost.tensor_size((*live.right_rep(right.front()))->canon_indices()); + bc.saving = cost.saving(m, n, c_final, l_size, r_size); + + std::size_t min_pos = std::numeric_limits::max(); + for (int l : left) + for (int r : right) + for (std::size_t p : live.positions(l, r)) { + bc.positions.push_back(p); + min_pos = std::min(min_pos, p); + } + bc.min_pos = min_pos; + for (int l : left) bc.left_reps.push_back(live.left_rep(l)); + for (int r : right) bc.right_reps.push_back(live.right_rep(r)); + bc.left = std::move(left); + bc.right = std::move(right); + bc.left_coeffs = std::move(alpha); + bc.right_coeffs = std::move(beta); + return bc; +} + +/// The best factorable fold derivable from the complete biclique \p left x +/// \p right. If the coefficient matrix is rank-1 the full m x n fold is used; +/// otherwise it is reduced to the more profitable of a one-sided slice -- a +/// single left vertex against all of \p right, or all of \p left against a +/// single right vertex -- which is always factorable (the relative signs ride +/// entirely on the partner side). Returns nullopt if nothing is profitable. +std::optional best_fold(std::vector const& left, + std::vector const& right, double c_final, + Live const& live, + BicliqueCostModel const& cost) { + std::optional best; + auto consider = [&](Biclique bc) { + if (bc.saving > 0. && supersedes(bc, best)) best = std::move(bc); + }; + + if (auto ab = factor_coeffs(left, right, live)) { + consider(make_biclique(left, right, std::move(ab->first), + std::move(ab->second), c_final, live, cost)); + } else { + // One-sided reductions (always rank-1). Coefficients fold onto the + // many-side partner: the single side keeps coefficient 1. + int const l0 = left.front(); + std::vector beta; + for (int r : right) beta.push_back(live.coeff(l0, r)); + consider(make_biclique({l0}, right, {scalar_type{1}}, std::move(beta), + c_final, live, cost)); + + int const r0 = right.front(); + std::vector alpha; + for (int l : left) alpha.push_back(live.coeff(l, r0)); + consider(make_biclique(left, {r0}, std::move(alpha), {scalar_type{1}}, + c_final, live, cost)); + } + return best; +} + +/// Internal safety backstop on the intersection-closure size -- deliberately a +/// file-local constant and not a field of \ref OptimizeOptions. It is not an +/// optimization tuning knob: the enumeration below is a first-cut, worst-case- +/// exponential closure (acceleration is deferred), and this merely caps its +/// memory/time on pathological inputs. Hitting it cannot change correctness: +/// the greedy driver re-enumerates every round, so a capped round only risks a +/// sub-optimal (never invalid) fold. There's no quality-for-speed trade a user +/// would sensibly dial. It sits far above any realistic problem size and so +/// never fires in practice; were it ever to fire, the remedy is the deferred +/// search acceleration, not a larger user-set bound. +constexpr std::size_t max_closure_size = 50000; + +/// Enumerate maximal bicliques of \p live and return the highest-saving fold +/// with positive saving (deterministic tie-break: lowest covered position). +/// Returns nullopt if none is profitable. First cut: full enumeration via +/// intersection-closure of the left neighborhoods (worst-case exponential; +/// acceleration is deferred). \p cap bounds the +/// closure; if hit we stop growing it and keep the best biclique found so far +/// (still correct -- the greedy driver re-enumerates each round, so a bounded +/// round only risks sub-optimality, never an invalid fold). +std::optional best_biclique(Live const& live, double c_final, + BicliqueCostModel const& cost, + std::size_t cap = max_closure_size) { + // Closure of the left neighborhoods under intersection -> candidate right + // vertex sets. Each closed set induces a maximal biclique. + std::set> closed; + std::vector> seeds; + for (auto const& [l, nb] : live.left_adj()) + if (!nb.empty() && closed.insert(nb).second) seeds.push_back(nb); + + std::vector> queue = seeds; + bool overflow = false; + for (std::size_t qi = 0; qi < queue.size() && !overflow; ++qi) { + for (std::vector const& sd : seeds) { + std::vector inter = intersect_sorted(queue[qi], sd); + if (!inter.empty() && closed.insert(inter).second) { + queue.push_back(inter); + if (closed.size() > cap) { + overflow = true; + break; + } + } + } + } + + std::optional best; + for (std::vector const& r0 : closed) { + // S_L = every left vertex adjacent to all of r0. + std::vector left; + for (auto const& [l, nb] : live.left_adj()) + if (std::includes(nb.begin(), nb.end(), r0.begin(), r0.end())) + left.push_back(l); + if (left.empty()) continue; + // S_R = intersection of the neighborhoods of S_L (>= r0): makes it maximal. + std::vector right = live.left_adj().at(left.front()); + for (std::size_t k = 1; k < left.size(); ++k) + right = intersect_sorted(right, live.left_adj().at(left[k])); + if (right.empty()) continue; + + std::optional bc = best_fold(left, right, c_final, live, cost); + if (!bc) continue; + if (supersedes(*bc, best)) best = std::move(bc); + } + return best; +} + +// =========================================================================== +// Emission +// =========================================================================== + +/// (sum of) the given factor subtrees, each scaled by its coefficient, as a +/// symbolic ExprPtr. A unit coefficient contributes the bare factor; a +/// non-unit one wraps it in a scalar Product. +ExprPtr side_expr(std::vector const& reps, + std::vector const& coeffs) { + auto term = [](Node const* n, scalar_type c) -> ExprPtr { + ExprPtr e = to_expr(*n); + if (c.is_identity()) return e; + return ex(c, ExprPtrList{e}, Product::Flatten::No); + }; + if (reps.size() == 1) return term(reps.front(), coeffs.front()); + Sum::summands_type parts; + parts.reserve(reps.size()); + for (std::size_t i = 0; i < reps.size(); ++i) + parts.push_back(term(reps[i], coeffs[i])); + return ex(Sum(parts.begin(), parts.end())); +} + +ExprPtr emit_biclique(Biclique const& bc) { + return ex(ExprPtrList{side_expr(bc.left_reps, bc.left_coeffs), + side_expr(bc.right_reps, bc.right_coeffs)}, + Product::Flatten::No); +} + +} // namespace + +ExprPtr factorize_multiterm( + Sum const& sum, container::vector> const& nodes, + OptimizeOptions const& opts) { + // Cost-driven selection needs index extents; guaranteed by construction in + // the normal optimize() flow. No structural fallback. + SEQUANT_ASSERT(opts.idx_to_extent); + SEQUANT_ASSERT(nodes.size() == sum.size()); + + BicliqueCostModel cost{opts}; + std::size_t const N = sum.size(); + + // Build: intern factor vertices and group splittable summands into + // per-signature buckets of factorable edges. + Interner interner; + std::map buckets = build_buckets(nodes, interner, cost); + + std::vector consumed(N, false); + std::vector folds; + + // Greedy cost-driven driver: repeatedly apply the single highest-saving + // maximal biclique across all buckets, until none has positive saving. + while (true) { + std::optional best; + for (auto const& [key, bucket] : buckets) { + Live live = Live::build(bucket, consumed, interner); + std::optional bc = best_biclique(live, bucket.c_final, cost); + if (!bc) continue; + if (supersedes(*bc, best)) best = std::move(bc); + } + if (!best) break; + + for (std::size_t p : best->positions) consumed[p] = true; + folds.push_back(emit_biclique(*best)); + } + + // Nothing profitable: return the same summands in the same order. + if (folds.empty()) return ex(sum); + + // Reassemble: untouched summands (original order) followed by folds. + Sum::summands_type out; + out.reserve(N + folds.size()); + for (std::size_t i = 0; i < N; ++i) + if (!consumed[i]) out.push_back(sum.summand(i)); + for (ExprPtr const& f : folds) out.push_back(f); + + return ex(Sum(out.begin(), out.end())); +} + +} // namespace sequant::opt diff --git a/SeQuant/core/optimize/multiterm.hpp b/SeQuant/core/optimize/multiterm.hpp new file mode 100644 index 0000000000..751b6235a4 --- /dev/null +++ b/SeQuant/core/optimize/multiterm.hpp @@ -0,0 +1,60 @@ +#ifndef SEQUANT_CORE_OPTIMIZE_MULTITERM_HPP +#define SEQUANT_CORE_OPTIMIZE_MULTITERM_HPP + +#include +#include +#include +#include + +namespace sequant { +class Sum; +} + +namespace sequant::opt { + +/// \brief Multi-term factorization over single-term-optimized summands. +/// +/// Pulls shared factors across the summands of \p sum (e.g. +/// \c A*B + A*C -> A*(B + C)), generalized to N terms at once and to the +/// two-sided biclique case \c (A+B)*(X+Y), and emits the result in place as a +/// nested \c ExprPtr (a \c Product whose factor is a \c Sum). Factorizations +/// are applied only when they lower the evaluation cost (cost-driven biclique +/// selection); structurally-shareable factors whose saving is non-positive are +/// left untouched. +/// +/// \par Scope (what is and isn't pulled out) +/// The matcher operates on the two factors of each summand's *top* binary +/// contraction, as fixed by single-term optimization. Two consequences: +/// - A shared *composite* factor is pulled out only when single-term +/// optimization already exposes it as one of those two top-level factors +/// (e.g. \c A*B*C + A*B*D folds to \c A*B*(C + D) because \c A*B is the +/// shared top factor). A composite that single-term brackets *away* into a +/// deeper subtree is not recovered: \c A*B*C*D + A*B*E*F is left as two +/// products when the chosen evaluation order buries \c A*B (e.g. +/// \c ((A*B)*C)*D), since the two top factors then differ. +/// - Only two-sided bicliques are emitted, and partner sums are not +/// recursively re-factored. An input that mathematically factors into three +/// or more groups, e.g. \c (A+B)(C+D)(E+F) expanded to eight terms, yields a +/// single two-sided fold such as \c (AC+AD+BC+BD)*(E+F), not the fully +/// nested form. +/// +/// \param sum The single-term-optimized sum to factor. +/// \param nodes Per-summand binarized eval nodes, positionally aligned with +/// \c sum.summands() (\c nodes[i] is the binary tree of +/// \c sum.summand(i)). Read-only inputs to the matcher. +/// \param opts Optimization options. \c opts.idx_to_extent must be populated +/// (asserted). The biclique \c saving is scored with the same base +/// per-contraction cost counter single-term optimization uses -- +/// \c flops_counter (DenseFLOPs) or \c memsize_counter +/// (DenseSize), selected by \c opts.objective_function. It does +/// not, however, apply single-term's \c volatile_weight (with \c +/// is_volatile_leaf) or \c footprint_weight adjustments: those +/// fields do not currently influence multi-term factorization. +/// \return The factored sum as an \c ExprPtr. +ExprPtr factorize_multiterm( + Sum const& sum, container::vector> const& nodes, + OptimizeOptions const& opts); + +} // namespace sequant::opt + +#endif // SEQUANT_CORE_OPTIMIZE_MULTITERM_HPP diff --git a/SeQuant/core/optimize/optimize.cpp b/SeQuant/core/optimize/optimize.cpp index 548768acb2..9154e8bbae 100644 --- a/SeQuant/core/optimize/optimize.cpp +++ b/SeQuant/core/optimize/optimize.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -222,18 +223,33 @@ ExprPtr optimize_impl(ExprPtr const& expr, OptimizeOptions const& opts, } Sum new_sum(std::move(new_smands), Sum::move_only_tag{}); - if (!reorder) return ex(std::move(new_sum)); - // Binarize once per optimized summand and hand the nodes to reorder() - // so they aren't re-built inside clusters(). NOTE: this runs sequentially - // by design -- see invariant (2) above. + bool const do_multiterm = opts.multiterm == MultiTermFactor::Enable; + if (!reorder && !do_multiterm) return ex(std::move(new_sum)); + + // Optional multi-term factorization first: it can merge summands + ExprPtr result; container::vector> nodes; - nodes.reserve(new_sum.size()); - // per-summand binarize for ordering only; positional head doesn't escape. - SEQUANT_PRAGMA_IGNORE_DEPRECATED_BEGIN - for (auto const& s : new_sum.summands()) nodes.push_back(binarize(s)); - SEQUANT_PRAGMA_IGNORE_DEPRECATED_END - return ex(opt::reorder(new_sum, nodes)); + if (do_multiterm) { + nodes.reserve(new_sum.size()); + // per-summand binarize; positional head doesn't escape. + SEQUANT_PRAGMA_IGNORE_DEPRECATED_BEGIN + for (auto const& s : new_sum.summands()) nodes.push_back(binarize(s)); + SEQUANT_PRAGMA_IGNORE_DEPRECATED_END + result = opt::factorize_multiterm(new_sum, nodes, opts); + } else { + result = ex(std::move(new_sum)); + } + + // reorder (independent of multiterm). An unchanged summand count means + // multiterm folded nothing, leaving `nodes` positionally valid, so reorder + // can reuse them; a fold shrinks the sum and forces a re-binarize. + if (reorder && result->is()) { + auto const& s = result->as(); + auto reuse_nodes = do_multiterm && s.size() == nodes.size(); + return ex(reuse_nodes ? opt::reorder(s, nodes) : opt::reorder(s)); + } + return result; } return expr->clone(); diff --git a/SeQuant/core/optimize/options.hpp b/SeQuant/core/optimize/options.hpp index bd3e82acdf..fe6accbfc5 100644 --- a/SeQuant/core/optimize/options.hpp +++ b/SeQuant/core/optimize/options.hpp @@ -57,6 +57,12 @@ struct CSEOptions { bool subnet = false; }; +/// Whether to perform multi-term factorization, pulling shared factors across +/// the summands of a Sum (\c A*B + A*C -> A*(B + C)) using a cost-driven +/// biclique search. Opt-in; disabled by default so existing output is +/// unchanged. +enum class MultiTermFactor { Disable, Enable }; + /// Roofline parameters for the peak objectives' secondary (tie-break) cost. /// When \c machine_balance > 0, the per-contraction tie-break cost becomes the /// roofline wall-time proxy \c max(flops, machine_balance * Q), where the data @@ -128,6 +134,10 @@ struct OptimizeOptions { /// enabling can reduce op counts at the cost of additional optimization time. CSEOptions CSE = {}; + /// Whether to pull shared factors across summands via cost-driven + /// multi-term factorization. Disabled by default. + MultiTermFactor multiterm = MultiTermFactor::Disable; + /// Caller-supplied Index to extent provider. If empty, defaults to /// \c IndexSpace::approximate_size(). index_to_extent_t idx_to_extent = {}; diff --git a/tests/integration/eval/calc.inp b/tests/integration/eval/calc.inp index 68e1386335..985613f485 100755 --- a/tests/integration/eval/calc.inp +++ b/tests/integration/eval/calc.inp @@ -5,9 +5,11 @@ spintrace no optimization single_term yes reuse_imeds yes +multi_term yes scf tightness normal log level 1 +print_exprs yes diff --git a/tests/integration/eval/calc_info.cpp b/tests/integration/eval/calc_info.cpp index 6006f7295c..c8c4ce98c6 100644 --- a/tests/integration/eval/calc_info.cpp +++ b/tests/integration/eval/calc_info.cpp @@ -40,6 +40,12 @@ CalcInfo make_calc_info(std::string_view config_file, auto const scf_opts = parser.opts_scf(); auto log_opts = parser.opts_log(); if (!output_file.empty()) log_opts.file = output_file.data(); + + // multi-term factorization is only in effect when single-term + // factorization is enabled + if (optm_opts.multi_term && !optm_opts.single_term) + std::wcerr << L"warning: multi_term ignored because single_term is off\n"; + auto const data_info = DataInfo{fock_or_eri_file, eri_or_fock_file}; return CalcInfo{eq_opts, optm_opts, scf_opts, log_opts, data_info}; } diff --git a/tests/integration/eval/calc_info.hpp b/tests/integration/eval/calc_info.hpp index 922664bd37..aea8ec57df 100644 --- a/tests/integration/eval/calc_info.hpp +++ b/tests/integration/eval/calc_info.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,8 @@ #include #include +#include +#include namespace sequant::eval { @@ -96,23 +99,59 @@ struct CalcInfo { template [[nodiscard]] container::vector> node_(ExprPtr const& expr, size_t rank) const { - using ranges::views::transform; auto trimmed = tail_factor(expr); - auto tform_and_save = - transform([st = optm_opts.single_term](const auto& expr) { - // SCF reference path: per-term binarize for energy/residual building; - // the head's bra/ket layout is consumed by integration helpers that - // index by slot ordinal and don't depend on conventional layout. - SEQUANT_PRAGMA_IGNORE_DEPRECATED_BEGIN - return binarize(st ? optimize(expr) : expr); - SEQUANT_PRAGMA_IGNORE_DEPRECATED_END - }) | - ranges::to_vector; - if (trimmed.size() > 0) { - return *trimmed | tform_and_save; - } else { // corner case: trimmed is an atom (i.e. single tensor) - return ranges::views::single(trimmed) | tform_and_save; + + // Optimize the whole sum at once: reorder() and multi-term factorization + // both act ACROSS summands, so a per-term optimize() would defeat them. + // optimize() single-term-optimizes every summand internally and returns the + // (possibly reordered/factored) sum, whose top-level summands become the + // eval nodes below. + // + // NOTE: multi-term factorization is performed inside optimize(), so it is + // tied to single_term: it is honored only when single_term is on. When + // single_term is off, optimize() is not called and multi_term has no effect + // (the raw expression is binarized as-is). OptimizeOptions has no switch to + // run multi-term/reorder without single-term optimization, and tying the + // two keeps this example's option handling simple (see + // OptionsOptimization). + OptimizeOptions opts; + opts.multiterm = optm_opts.multi_term ? MultiTermFactor::Enable + : MultiTermFactor::Disable; + ExprPtr const optimized = + optm_opts.single_term ? optimize(trimmed, opts) : trimmed; + + // Each top-level summand of the result is one eval node; a non-Sum result + // (single-term or atomic equation) is a single summand. + container::vector summands; + if (optimized->is()) + for (auto const& s : *optimized) summands.push_back(s); + else + summands.push_back(optimized); + + // SCF reference path: per-summand binarize for energy/residual building; + // the head's bra/ket layout is consumed by integration helpers that index + // by slot ordinal and don't depend on conventional layout. + container::vector> nodes; + nodes.reserve(summands.size()); + SEQUANT_PRAGMA_IGNORE_DEPRECATED_BEGIN + for (auto const& s : summands) nodes.push_back(binarize(s)); + SEQUANT_PRAGMA_IGNORE_DEPRECATED_END + + // Print the optimized expressions once (one summand per line) in the + // deserialize text format, just before returning -- before any evaluation. + // A header reports the per-equation term counts before/after optimization + // (multi-term factorization can merge summands), a footer closes the block. + if (log_opts.print_exprs) { + const std::size_t nb = trimmed->is() ? trimmed->size() : 1; + const std::size_t na = summands.size(); + std::wcout << std::format( + L"===== R{} terms (# terms before/after opt {}/{}) ======\n", rank, + nb, na); + for (auto const& s : summands) std::wcout << serialize(s) << L'\n'; + std::wcout << L"------------\n"; } + + return nodes; } }; diff --git a/tests/integration/eval/options.cpp b/tests/integration/eval/options.cpp index 694ab8c9a7..074dcb4be1 100644 --- a/tests/integration/eval/options.cpp +++ b/tests/integration/eval/options.cpp @@ -63,7 +63,8 @@ void ParseOptionsEquations::update(std::string_view arg_name, std::string ParseOptionsOptimization::help() const { return bool_parser.help(single_term) + "\n" + bool_parser.help(reuse_imeds) + - "\n" + bool_parser.help(cache_leaves); + "\n" + bool_parser.help(cache_leaves) + "\n" + + bool_parser.help(multi_term); } OptionsOptimization ParseOptionsOptimization::opts() const { return opts_; } @@ -76,6 +77,8 @@ void ParseOptionsOptimization::update(std::string_view arg_name, opts_.reuse_imeds = bool_parser.parse(value); else if (arg_name == cache_leaves) opts_.cache_leaves = bool_parser.parse(value); + else if (arg_name == multi_term) + opts_.multi_term = bool_parser.parse(value); else throw detail::ErrorArgNameInvalid{arg_name.data()}; } @@ -106,14 +109,20 @@ void ParseOptionsSCF::update(std::string_view arg_name, } // unreachable } -std::string ParseOptionsLog::help() const { return level_parser.help(level); } +std::string ParseOptionsLog::help() const { + return level_parser.help(level) + "\n" + bool_parser.help(print_exprs); +} OptionsLog ParseOptionsLog::opts() const { return opts_; } void ParseOptionsLog::update(std::string_view arg_name, std::string_view value) { - if (arg_name != level) throw detail::ErrorArgNameInvalid{arg_name.data()}; - opts_.level = level_parser.parse(value); + if (arg_name == level) + opts_.level = level_parser.parse(value); + else if (arg_name == print_exprs) + opts_.print_exprs = bool_parser.parse(value); + else + throw detail::ErrorArgNameInvalid{arg_name.data()}; } } // namespace detail diff --git a/tests/integration/eval/options.hpp b/tests/integration/eval/options.hpp index c2e021e023..a03baf1dde 100644 --- a/tests/integration/eval/options.hpp +++ b/tests/integration/eval/options.hpp @@ -25,6 +25,11 @@ struct OptionsOptimization { bool single_term = true; bool reuse_imeds = true; bool cache_leaves = true; + /// Multi-term factorization (\c A*B + A*C -> A*(B + C)). It is applied as + /// part of the optimize() call and therefore only takes effect when \c + /// single_term is also on; with \c single_term off, optimize() is not called + /// and this flag has no effect. + bool multi_term = false; }; struct OptionsSCF { @@ -41,6 +46,10 @@ struct OptionsLog { inline static constexpr size_t MAX_LEVEL = 1; size_t level = 1; std::string file = ""; + /// Whether to print the optimized SeQuant expressions (one summand per line, + /// in the deserialize/serialization text format) once, after optimization and + /// before any evaluation. + bool print_exprs = false; }; namespace detail { @@ -146,6 +155,8 @@ class ParseOptionsOptimization { inline static std::string_view const cache_leaves{"cache_leaves"}; + inline static std::string_view const multi_term{"multi_term"}; + OptionsOptimization opts_; public: @@ -186,8 +197,12 @@ class ParseOptionsLog { inline static ArgValInt const level_parser{OptionsLog::MIN_LEVEL, OptionsLog::MAX_LEVEL}; + inline static const auto bool_parser = ArgValBool{}; + inline static std::string_view const level{"level"}; + inline static std::string_view const print_exprs{"print_exprs"}; + OptionsLog opts_{}; public: diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 3411d2ee0e..741b3ede05 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -55,7 +55,7 @@ target_compile_definitions(unit_tests-sequant-eval-obj PRIVATE ########################## set(optimize_test_sources "test_extract_subtrees.cpp" - "test_fusion.cpp" + "test_multiterm.cpp" "test_optimize.cpp" ) add_library(unit_tests-sequant-optimize-obj OBJECT ${optimize_test_sources}) diff --git a/tests/unit/test_fusion.cpp b/tests/unit/test_fusion.cpp deleted file mode 100644 index 7ff4f70ab6..0000000000 --- a/tests/unit/test_fusion.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include - -#include "catch2_sequant.hpp" - -#include -#include -#include - -#include -#include -#include -#include - -TEST_CASE("fusion", "[optimize]") { - using sequant::opt::Fusion; - using namespace sequant; - std::vector> fused_terms{ - { - L"1/2 f{i3;i1} t{a1,a2;i2,i3}", // lhs - L"1/2 f{i3;a3} t{a3;i1} t{a1,a2;i2,i3}", // rhs - L"1/2(f{i3;i1} + f{i3;a3} t{a3;i1}) t{a1,a2;i2,i3}" // fused form - }, - - {L"1/8 g{a1,a2;a3,a4} t{a3,a4;i1,i2}", - L"1/4 g{a1,a2;a3,a4} t{a3;i1} t{a4;i2}", - L"1/8 g{a1,a2;a3,a4}(t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"}, - - {L"1/4 g{a1,a2;a3,a4} t{a3;i1} t{a4;i2}", - L"1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}", - L"1/4(g{a1,a2;a3,a4} + g{i3,i4;a3,a4} t{a1;i3} t{a2;i4}) t{a3;i1} " - L"t{a4;i2}"}, - - {L"1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3,a4;i1,i2}", - L"1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}", - L"1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} " - L" (t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"}, - - {L"-1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3,a4;i1,i2}", - L"-1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}", - L"-1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} " - L" (t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"} - - }; - - for (auto&& [l, r, f] : fused_terms) { - auto const le = deserialize(l); - auto const re = deserialize(r); - auto const fe = deserialize(f); - auto fu = Fusion{le->as(), re->as()}; - REQUIRE((fu.left() || fu.right())); - - auto const& fue = fu.left() ? fu.left() : fu.right(); - - REQUIRE(fe == fue); - } -} diff --git a/tests/unit/test_multiterm.cpp b/tests/unit/test_multiterm.cpp new file mode 100644 index 0000000000..492abae9fe --- /dev/null +++ b/tests/unit/test_multiterm.cpp @@ -0,0 +1,560 @@ +#include + +#include "catch2_sequant.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +sequant::ExprPtr parse_antisymm(std::wstring_view s) { + using namespace sequant; + return deserialize(s, {.def_perm_symm = Symmetry::Antisymm}); +} + +sequant::ExprPtr parse_nonsymm(std::wstring_view s) { + using namespace sequant; + return deserialize(s, {.def_perm_symm = Symmetry::Nonsymm}); +} + +std::wstring latex(sequant::ExprPtr const& e) { return sequant::to_latex(e); } + +/// Every factored-group summand of \p sum, in summand order. A factored group +/// is a Product whose factors include a Sum -- i.e. shared * (partner1 + +/// partner2 + +/// ...). +std::vector find_all_factored(sequant::ExprPtr const& sum) { + using namespace sequant; + std::vector out; + if (!sum->is()) return out; + for (auto const& s : sum->as().summands()) { + if (!s->is()) continue; + for (auto const& f : s->as().factors()) + if (f->is()) { + out.push_back(s); + break; + } + } + return out; +} + +/// The first factored-group summand of \p sum (see \ref find_all_factored), or +/// null if none. +sequant::ExprPtr find_factored(sequant::ExprPtr const& sum) { + auto const all = find_all_factored(sum); + return all.empty() ? sequant::ExprPtr{} : all.front(); +} + +} // namespace + +TEST_CASE("multiterm factorization", "[multiterm]") { + using namespace sequant; + + // Sized spaces so the cost model is well-defined. Set once here, in the + // outermost scope, rather than per section; the cost-driven-winner section, + // whose expected fold depends on the occ-vs-virt ordering, overrides them + // locally via set_sizes. + auto ctx_resetter = set_scoped_default_context(get_default_context().clone()); + auto reg = get_default_context().mutable_index_space_registry(); + auto set_sizes = [®](std::size_t occ, std::size_t virt) { + reg->retrieve_ptr(L"i")->approximate_size(occ); + reg->retrieve_ptr(L"a")->approximate_size(virt); + }; + set_sizes(/*occ=*/10, /*virt=*/20); + + // Multi-term factorization enabled; NoReorder keeps the fold structure + // observable. The DenseSize and reorder sections build their own options. + OptimizeOptions const mt{.reorder = ReorderSum::NoReorder, + .multiterm = MultiTermFactor::Enable}; + + SECTION("disabled is the default and matches an explicit Disable") { + auto const expr = parse_antisymm( + L"G{i1,i2;a1,a2} T{a1,a2;i1,i2}" + L" + G{i1,i2;a1,a2} Z{a1,a2;i1,i2}"); + REQUIRE(*optimize(expr, {.multiterm = MultiTermFactor::Disable}) == + *optimize(expr)); + } + + SECTION("enabling the flag does not throw and yields a Sum") { + auto const expr = parse_antisymm( + L"G{i1,i2;a1,a2} T{a1,a2;i1,i2}" + L" + G{i1,i2;a1,a2} Z{a1,a2;i1,i2}"); + ExprPtr res; + REQUIRE_NOTHROW(res = optimize(expr, mt)); + REQUIRE(res); + REQUIRE(res->is()); + } + + SECTION("one-sided fold: A*B + A*C -> A*(B + C) when it lowers cost") { + // V contracted with T (or B) over a3,a4 leaves a tensor R{a1,a2;i1,i2}; the + // contraction (O(a^4 i^2)) dwarfs the partner-sum (O(a^2 i^2)), so folding + // pays off. + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} B{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + // EquivalentTo expands both sides, so this is the round-trip guard *and* + // pins the exact content (the shared V, the partners, their indices); the + // factored RHS doubles as documentation of the intended output. + REQUIRE_THAT(res, + EquivalentTo(parse_antisymm( + L"V{a1,a2;a3,a4} (T{a3,a4;i1,i2} + B{a3,a4;i1,i2})"))); + // Structural: a fold actually happened (EquivalentTo cannot see this -- it + // expands the factored form away). + REQUIRE(res->size() == 1); + REQUIRE(find_factored(res)); + } + + SECTION("N-ary one-sided fold: A*B + A*C + A*D -> A*(B + C + D)") { + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} U{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, EquivalentTo(parse_antisymm( + L"V{a1,a2;a3,a4} (T{a3,a4;i1,i2}" + L" + B{a3,a4;i1,i2} + U{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 1); + REQUIRE(find_factored(res)); + } + + SECTION("transposed partners fold: V*t + V*t^T -> V*(t + t^T)") { + // After spin-tracing CC equations a shared factor multiplies two + // index-transposed copies of a (now non-antisymmetric) amplitude. They are + // two distinct partner vertices on one side of the biclique, so they fold + // like any other one-sided pair. (Antisymmetric t would instead collapse + // the two copies into one term up to sign, so this uses non-antisymmetric + // tensors -- the realistic post-spintracing case.) + auto const expr = parse_nonsymm( + L"V{i1,i2;a1,a2} t{a1,a2;i3,i4}" + L" + V{i1,i2;a1,a2} t{a2,a1;i4,i3}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, + EquivalentTo(parse_nonsymm( + L"V{i1,i2;a1,a2} (t{a1,a2;i3,i4} + t{a2,a1;i4,i3})"))); + REQUIRE(res->size() == 1); + REQUIRE(find_factored(res)); + } + + SECTION("a bare-tensor summand passes through while its neighbors fold") { + // V*T + V*B fold to V*(T + B); the bare tensor D{a1,a2;i1,i2} is a leaf (no + // contraction to split), so extract_core() returns nullopt and it is never + // interned, scored, or consumed. Reassembly emits untouched summands first + // (original order), then the folds: + // D + V*(T + B) + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" + D{a1,a2;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, + EquivalentTo(parse_antisymm( + L"D{a1,a2;i1,i2}" + L" + V{a1,a2;a3,a4} (T{a3,a4;i1,i2} + B{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 2); // the untouched bare tensor + one fold + REQUIRE(find_factored(res)); + // D survived verbatim and, with NoReorder, precedes the fold. + auto const& smands = res->as().summands(); + REQUIRE(smands.front()->is()); + REQUIRE(smands.front()->as().label() == L"D"); + } + + SECTION("two-sided biclique: AX + AY + BX + BY -> (A + B)*(X + Y)") { + // U,V share a shape (left factors); T,B share a shape (right factors). The + // four contractions form a complete 2x2 bipartite graph and fold to + // (U + V)*(T + B). + auto const expr = parse_antisymm( + L"U{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + U{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} B{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, EquivalentTo(parse_antisymm( + L"(U{a1,a2;a3,a4} + V{a1,a2;a3,a4})" + L" (T{a3,a4;i1,i2} + B{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 1); // all four folded into one product + auto const factored = find_factored(res); + REQUIRE(factored); + // Both factors are 2-term sums: (U + V) and (T + B). + for (auto const& f : factored->as().factors()) { + REQUIRE(f->is()); + REQUIRE(f->size() == 2); + } + } + + SECTION("incomplete graph: cost-driven winner flips with extents") { + // Topology: edges (U,T),(U,B),(V,T) -- no (V,B). Two maximal one-sided + // folds compete with identical avoided contraction cost C: + // {U}x{T,B} -> U*(T + B), build cost ~ size(T) = a^2 i^2 + // {U,V}x{T} -> (U + V)*T, build cost ~ size(U) = a^4 + // Greedy keeps the higher-saving fold -- the one whose summed side is + // cheaper to build: it sums the lighter factors and shares the heavier one. + // Which factor is heavier is set purely by the occ-vs-virt extent ordering, + // so the same expression under both orderings must flip the shared factor + // (U <-> T). The flip shows the choice is cost-driven; the round-trip is + // the cost-model-independent guard. + std::wstring const expr_text = + L"U{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + U{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} T{a3,a4;i1,i2}"; + + // The shared (non-summed) factor's tensor label, with the structural guards + // that hold regardless of which side wins: exactly one fold over a 2-term + // partner sum, plus one untouched leftover summand. + auto winning_shared_label = [](ExprPtr const& res) -> std::wstring { + REQUIRE(res->is()); + REQUIRE(res->size() == 2); // one fold + one untouched leftover + auto const folded = find_all_factored(res); + REQUIRE(folded.size() == 1); + ExprPtr shared, partner_sum; + for (auto const& f : folded.front()->as().factors()) + (f->is() ? partner_sum : shared) = f; + REQUIRE(partner_sum); + REQUIRE(partner_sum->size() == 2); + REQUIRE(shared); + REQUIRE(shared->is()); + return std::wstring{shared->as().label()}; + }; + + // virt-heavy (a=20 > i=10): size(U)=a^4 dominates size(T)=a^2 i^2, so U is + // the heavier factor and is shared -- summing the lighter T, B into + // U*(T + B). + set_sizes(/*occ=*/10, /*virt=*/20); + auto const expr_vh = parse_antisymm(expr_text); + auto const res_vh = optimize(expr_vh, mt); + REQUIRE(winning_shared_label(res_vh) == L"U"); + REQUIRE_THAT(res_vh, EquivalentTo(expr_vh)); + + // occ-heavy (i=30 > a=10): now size(T)=a^2 i^2 dominates size(U)=a^4, so T + // is the heavier factor and is shared -- summing the lighter U, V into + // (U + V)*T. + set_sizes(/*occ=*/30, /*virt=*/10); + auto const expr_oh = parse_antisymm(expr_text); + auto const res_oh = optimize(expr_oh, mt); + REQUIRE(winning_shared_label(res_oh) == L"T"); + REQUIRE_THAT(res_oh, EquivalentTo(expr_oh)); + } + + SECTION("full contraction saves nothing: shareable but no fold") { + // G*T + G*Z both fully contract to a scalar -- exactly the saving()==0 + // boundary, and structurally so (not by extent tuning): for a full + // contraction the avoided cost C equals the product of the whole index set, + // and the only build cost is the partner sum (T + Z), whose footprint + // size(T) is that same whole index set. With the one-sided + // saving = (m*n - 1)*C - (n - 1)*size(T) at m=1, n=2 this is C - size(T) = + // 0 for any extents, so nothing folds. The cancellation is exact only under + // the current (m*n - 1)/(n - 1) coefficients (BicliqueCostModel::saving, + // multiterm.cpp); revisit this no-fold check if those change. The + // partial-contraction sibling below brackets this boundary from above. + auto const expr = parse_antisymm( + L"G{i1,i2;a1,a2} T{a1,a2;i1,i2}" + L" + G{i1,i2;a1,a2} Z{a1,a2;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE(res->is()); + REQUIRE(res->size() == 2); + REQUIRE_FALSE(find_factored(res)); + } + + SECTION("partial contraction clears the threshold: V*(T + Z)") { + // Contrast with the scalar case: V shares with T and Z, but the contraction + // is partial (over a3,a4 only), so the result keeps free indices and the + // avoided cost C = O(a^4 i^2) strictly exceeds the partner-build + // size(T) = O(a^2 i^2). saving = C - size(T) > 0, so V*(T + Z) folds. With + // the scalar case sitting exactly on saving()==0, this brackets that + // boundary from the positive side. + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} Z{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, + EquivalentTo(parse_antisymm( + L"V{a1,a2;a3,a4} (T{a3,a4;i1,i2} + Z{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 1); // V * (T + Z) + REQUIRE(find_factored(res)); + } + + SECTION("DenseSize objective: the fold still pays") { + // Unlike the other sections (default DenseFLOPs metric), this one sets + // objective_function to DenseSize, routing + // BicliqueCostModel::contraction_cost through memsize_counter instead of + // flops_counter. V*T + V*B still folds: + // under DenseSize the avoided contraction's element footprint (O(a^4), the + // V leg dominating) far exceeds the cost of building the (T + B) partner + // sum. + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} B{a3,a4;i1,i2}"); + auto const res = + optimize(expr, {.objective_function = ObjectiveFunction::DenseSize, + .reorder = ReorderSum::NoReorder, + .multiterm = MultiTermFactor::Enable}); + REQUIRE_THAT(res, + EquivalentTo(parse_antisymm( + L"V{a1,a2;a3,a4} (T{a3,a4;i1,i2} + B{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 1); // V * (T + B) + REQUIRE(find_factored(res)); + } + + SECTION("honors reorder independently of multiterm") { + // reorder and multiterm are independent options, so enabling multiterm must + // not disable reorder. The input keeps the two effects separable: its two + // g*t summands fully contract to scalars, so folding saves nothing (the + // saving()==0 boundary) and the factorizer leaves all three summands alone. + // reorder, on the other hand, does move them. So latex -- which (unlike + // EquivalentTo, which canonicalizes summand order away) is order-sensitive + // -- is the right comparison here. + auto const expr = parse_antisymm( + L"g{i3,i4;a1,a2} t{a1,a2;i5,i6} A{i5,i6;i3,i4}" + L" + p{i1;i2} q{i2;i1}" + L" + g{i3,i4;a1,a2} t{a1,a2;i5,i6} B{i5,i6;i3,i4}"); + + auto const mt_noreorder = + optimize(expr, {.reorder = ReorderSum::NoReorder, + .multiterm = MultiTermFactor::Enable}); + auto const mt_reorder = optimize( + expr, + {.reorder = ReorderSum::Reorder, .multiterm = MultiTermFactor::Enable}); + + // Nothing folds (full contraction, saving 0), so all three summands survive + // and reorder is the only thing that can change the output. + REQUIRE(mt_noreorder->is()); + REQUIRE(mt_noreorder->size() == 3); + REQUIRE_FALSE(find_factored(mt_noreorder)); + + // With the bug, reorder got skipped under multiterm and these two came out + // identical; fixed, they differ. + REQUIRE(latex(mt_noreorder) != latex(mt_reorder)); + + // And the order matches plain reorder: with nothing to fold, multiterm + // leaves reorder's result untouched. + auto const reorder_only = + optimize(expr, {.reorder = ReorderSum::Reorder, + .multiterm = MultiTermFactor::Disable}); + REQUIRE(latex(mt_reorder) == latex(reorder_only)); + } + + SECTION("reorder operates on a folded Product(Sum) summand") { + // The section above folds nothing, so reorder never re-binarizes a + // Product(Sum). Here a fold and a reorder happen together: + // V*T + V*B -> V*(T + B) partial contraction, saving > 0 + // g*t*A, p*q, g*t*C full contraction, saving 0 (no fold) + // g*t*A and g*t*C share g*t, split by p*q, so reorder has something to + // cluster. + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" + g{i3,i4;a5,a6} t{a5,a6;i5,i6} A{i5,i6;i3,i4}" + L" + p{i7;i8} q{i8;i7}" + L" + g{i3,i4;a5,a6} t{a5,a6;i5,i6} C{i5,i6;i3,i4}"); + + auto const noreorder = + optimize(expr, {.reorder = ReorderSum::NoReorder, + .multiterm = MultiTermFactor::Enable}); + auto const reordered = optimize( + expr, + {.reorder = ReorderSum::Reorder, .multiterm = MultiTermFactor::Enable}); + + // V*(T + B) + g*t*A + p*q + g*t*C: four summands either way. + REQUIRE(noreorder->is()); + REQUIRE(noreorder->size() == 4); + REQUIRE(reordered->is()); + REQUIRE(reordered->size() == 4); + + // V*(T + B) survives reorder's clusters()/binarize round-trip on the + // Product(Sum): present in both, structurally identical (reorder permutes, + // never rewrites). + auto const folded_noreorder = find_factored(noreorder); + auto const folded_reordered = find_factored(reordered); + REQUIRE(folded_noreorder); + REQUIRE(folded_reordered); + REQUIRE(*folded_reordered == *folded_noreorder); + + // reorder clustered g*t*A next to g*t*C: same summands, different order, + // so unequal. + REQUIRE_FALSE(*noreorder == *reordered); + } + + SECTION("distinct contraction signatures fold independently, not merged") { + // All four summands share the external indices {a1,a2,i1,i2}. The first two + // contract over a virtual pair (a3,a4); the last two over an occupied pair + // (i3,i4). Distinct signatures => two independent buckets => two separate + // folds V*(T + B) and W*(X + Y), never one cross-bucket merge. + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" + W{a1,a2;i3,i4} X{i3,i4;i1,i2}" + L" + W{a1,a2;i3,i4} Y{i3,i4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, + EquivalentTo(parse_antisymm( + L"V{a1,a2;a3,a4} (T{a3,a4;i1,i2} + B{a3,a4;i1,i2})" + L" + W{a1,a2;i3,i4} (X{i3,i4;i1,i2} + Y{i3,i4;i1,i2})"))); + REQUIRE(res->size() == 2); // one fold per bucket + REQUIRE(find_all_factored(res).size() == 2); + } + + SECTION("sign fold: A*B - A*C -> A*(B - C)") { + // The factorizer peels the (-1) prefactor off the summand and folds the + // relative sign into the partner. + auto const expr = parse_antisymm( + L"V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" - V{a1,a2;a3,a4} B{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, + EquivalentTo(parse_antisymm( + L"V{a1,a2;a3,a4} (T{a3,a4;i1,i2} - B{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 1); // V * (T - B) + REQUIRE(find_factored(res)); + } + + SECTION("rational prefactors: 2*A*B + 3*A*C -> A*(2B + 3C)") { + auto const expr = parse_antisymm( + L"2 V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + 3 V{a1,a2;a3,a4} B{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, + EquivalentTo(parse_antisymm( + L"V{a1,a2;a3,a4} (2 T{a3,a4;i1,i2} + 3 B{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 1); + REQUIRE(find_factored(res)); + } + + SECTION("rank-1 two-sided sign matrix folds as (U - V)*(T + B)") { + // U*T + U*B - V*T - V*B. The sign matrix [[+,+],[-,-]] is multiplicative + // rank 1, so the whole 2x2 still folds to a single product, the sign riding + // the left factor-sum. + auto const expr = parse_antisymm( + L"U{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + U{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" - V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" - V{a1,a2;a3,a4} B{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, EquivalentTo(parse_antisymm( + L"(U{a1,a2;a3,a4} - V{a1,a2;a3,a4})" + L" (T{a3,a4;i1,i2} + B{a3,a4;i1,i2})"))); + REQUIRE(res->size() == 1); // single (U - V)*(T + B) + auto const factored = find_factored(res); + REQUIRE(factored); + for (auto const& f : factored->as().factors()) { + REQUIRE(f->is()); + REQUIRE(f->size() == 2); + } + } + + SECTION("non-rank-1 sign matrix reduces to one-sided folds") { + // U*T + U*B + V*T - V*B. The sign matrix [[+,+],[+,-]] does not factor, so + // a single (U +/- V)*(T +/- B) would be wrong. The factorizer must decline + // the 2x2 and reduce to two one-sided folds. The reduction is not unique + // (it may share either side), so EquivalentTo on the original sum is the + // correctness guard while the structural checks pin "two one-sided folds, + // no bogus 2x2". + auto const expr = parse_antisymm( + L"U{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" + U{a1,a2;a3,a4} B{a3,a4;i1,i2}" + L" + V{a1,a2;a3,a4} T{a3,a4;i1,i2}" + L" - V{a1,a2;a3,a4} B{a3,a4;i1,i2}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, EquivalentTo(expr)); + REQUIRE(res->size() == 2); // two one-sided folds + REQUIRE(find_all_factored(res).size() == 2); + for (auto const& f : find_all_factored(res)) { + ExprPtr partner_sum; + for (auto const& g : f->as().factors()) + if (g->is()) partner_sum = g; + REQUIRE(partner_sum); + REQUIRE(partner_sum->size() == 2); + } + } + + SECTION("canonicalization-phase fold lands a relative sign on the partner") { + // Distinct from the scalar-prefactor folds above: there the relative sign + // was an explicit input coefficient that extract_core peels off. Here + // neither summand carries any sign or scalar -- every relative sign comes + // purely from canonicalization. Both summands share the intermediate + // M = G*W, contracted over the virtual pair a3,a4 (so M is a child factor; + // only an intermediate can carry a non-trivial canon_phase). The second + // summand writes G's contracted pair swapped (a4,a3); antisymmetric G makes + // M canonicalize to the same indices but opposite phase. Phase-relaxed + // matching interns both Ms into one vertex; the relative -1 rides the edge + // and lands on the partner: M*(X - Y). size()==1 is the proof the + // relaxation engaged (phase-strict would leave size()==2). + auto const expr = parse_antisymm( + L"G{i1,i2;a3,a4} W{a3,a4;i3,i4} X{i3,i4;a1,a2}" + L" + G{i1,i2;a4,a3} W{a3,a4;i3,i4} Y{i3,i4;a1,a2}"); + auto const res = optimize(expr, mt); + REQUIRE(res->is()); + REQUIRE(res->size() == 1); // folded => phase relaxation engaged + + auto const factored = find_factored(res); + REQUIRE(factored); + ExprPtr shared, partner_sum; + for (auto const& f : factored->as().factors()) + (f->is() ? partner_sum : shared) = f; + // The shared factor is the G*W intermediate (a Product), not a bare leaf. + REQUIRE(shared); + REQUIRE(shared->is()); + REQUIRE(partner_sum); + REQUIRE(partner_sum->size() == 2); + + // The relative canonicalization sign landed on exactly one partner, even + // though neither input summand carried a sign or scalar -- the (X - Y) + // shape. side_expr() emits a +1 partner as a bare Tensor and a non-unit one + // as a scalar Product, so exactly one partner is a negative Product. + std::size_t negative_partners = 0; + for (auto const& g : partner_sum->as().summands()) + if (g->is() && g->as().scalar().real() < 0) + ++negative_partners; + REQUIRE(negative_partners == 1); + + // Round-trip is the cost-model-independent correctness guard. + REQUIRE_THAT(res, EquivalentTo(expr)); + } + + SECTION( + "three-factor input yields a single 2-sided fold, not nested factors") { + // (A + B)(C + D)(E + F) fully expanded into 8 terms. The factorizer only + // emits 2-sided bicliques and never re-factors the partner sums, so it + // produces a single fold (AC + AD + BC + BD)*(E + F) -- NOT the nested + // 3-factor form. This locks the documented scope limit (see multiterm.hpp). + // Non-antisymmetric tensors so the eight terms stay independent. + auto const expr = parse_nonsymm( + L"A{i1,i2;a1,a2} C{a1,a2;a3,a4} E{a3,a4;i3,i4}" + L" + A{i1,i2;a1,a2} D{a1,a2;a3,a4} E{a3,a4;i3,i4}" + L" + B{i1,i2;a1,a2} C{a1,a2;a3,a4} E{a3,a4;i3,i4}" + L" + B{i1,i2;a1,a2} D{a1,a2;a3,a4} E{a3,a4;i3,i4}" + L" + A{i1,i2;a1,a2} C{a1,a2;a3,a4} F{a3,a4;i3,i4}" + L" + A{i1,i2;a1,a2} D{a1,a2;a3,a4} F{a3,a4;i3,i4}" + L" + B{i1,i2;a1,a2} C{a1,a2;a3,a4} F{a3,a4;i3,i4}" + L" + B{i1,i2;a1,a2} D{a1,a2;a3,a4} F{a3,a4;i3,i4}"); + auto const res = optimize(expr, mt); + REQUIRE_THAT(res, EquivalentTo(expr)); // correctness guard + REQUIRE(res->size() == 1); + auto const factored = find_factored(res); + REQUIRE(factored); + + // Exactly the two sides of one biclique: a 4-term left sum (still + // unfactored) and the 2-term (E + F). Both factors are Sums. + std::vector sum_sizes; + for (auto const& f : factored->as().factors()) { + REQUIRE(f->is()); + sum_sizes.push_back(f->size()); + // No partner is itself factored: no nested 3-factor form. + for (auto const& g : f->as().summands()) + if (g->is()) + for (auto const& h : g->as().factors()) + REQUIRE_FALSE(h->is()); + } + std::sort(sum_sizes.begin(), sum_sizes.end()); + REQUIRE(sum_sizes == std::vector{2, 4}); + } +}