Maestro 0.1.0
Unified interface for quantum circuit simulation
Loading...
Searching...
No Matches
ForestContractor.h
Go to the documentation of this file.
1
18
19#pragma once
20
21#ifndef __FOREST_CONTRACTOR_H_
22#define __FOREST_CONTRACTOR_H_ 1
23
#include "BaseContractor.h"

#include <algorithm>

#include <boost/container_hash/hash.hpp>
27
28namespace TensorNetworks {
29
36public:
43 double Contract(const TensorNetwork &network, Types::qubit_t qubit) override {
44 std::vector<Eigen::Index> keys;
45 std::unordered_map<Eigen::Index, Eigen::Index> keysKeys;
46
47 TensorsMap tensors =
48 InitializeTensors(network, qubit, keys, keysKeys, false);
49 std::map<Eigen::Index, std::shared_ptr<TensorNode>> tensorsMap(
50 tensors.begin(), tensors.end());
51
52 using TensorPair = std::pair<Eigen::Index, Eigen::Index>;
53 std::unordered_set<TensorPair, boost::hash<TensorPair>> visitedPairs;
54
55 Eigen::Index maxRank = 4;
56 Eigen::Index nextRank = 4;
57
58 // the first two initialized just to keep the compiler happy
59 Eigen::Index tensor1Id = 0;
60 Eigen::Index tensor2Id = 1;
61 Eigen::Index resultRank = 4;
62 Eigen::Index bestCost = 4;
63
64 // while there is more than one tensor...
65 while (tensors.size() > 1) {
66 // by going like this we pick up the tensors that are closer to the end of
67 // the circuit, to contract with the 'super' ones can be very efficient
68 // for forest kind of circuits
69
70 bool found = false;
71 bool nextRankSet = false;
72
73 for (auto tensorIt = tensorsMap.rbegin(); tensorIt != tensorsMap.rend();
74 ++tensorIt) {
75 const auto &tensor = tensorIt->second;
76 const auto curTensorId = tensor->GetId();
77
78 // despite having a bigger rank, the contraction could lead to a smaller
79 // one to match the rank limit, so don't do this
80 // if (tensor->GetRank() > maxRank) continue;
81
82 for (Eigen::Index ti = 0;
83 ti < static_cast<Eigen::Index>(tensor->connections.size()); ++ti) {
84 const auto nextTensorId = tensor->connections[ti];
85
86 if (nextTensorId != TensorNode::NotConnected) {
87 auto t1 = curTensorId;
88 auto t2 = nextTensorId;
89 if (t1 > t2)
90 std::swap(t1, t2);
91
92 const auto p = std::make_pair(t1, t2);
93 if (visitedPairs.find(p) != visitedPairs.end())
94 continue;
95 else
96 visitedPairs.insert(p);
97
98 const Eigen::Index newRank =
99 GetResultRank(tensor, tensors[nextTensorId]);
100
101 if (newRank <= maxRank) {
102 const Eigen::Index Cost =
103 newRank -
104 (tensor->GetRank() + tensors[nextTensorId]->GetRank()) / 2;
105
106 if (!found || newRank < resultRank ||
107 (newRank == resultRank && Cost < bestCost)) {
108 tensor1Id = curTensorId;
109 tensor2Id = nextTensorId;
110 resultRank = newRank;
111 bestCost = Cost;
112 found = true;
113 }
114 }
115
116 if (!found) {
117 if (!nextRankSet) {
118 nextRankSet = true;
119 nextRank = newRank;
120 } else
121 nextRank = std::min(nextRank, newRank);
122 }
123 }
124 }
125 }
126
127 visitedPairs.clear();
128
129 if (!found) {
130 maxRank = nextRank;
131 continue;
132 }
133
134 ContractNodes(qubit, tensors, tensor1Id, tensor2Id, resultRank);
135 tensorsMap[tensor1Id] = tensors[tensor1Id];
136 tensorsMap.erase(tensor2Id);
137
138 if (resultRank == 0) {
139 if (tensors.size() == 1 || tensors[tensor1Id]->contractsTheNeededQubit)
140 return std::real(tensors[tensor1Id]->tensor->atOffset(0));
141
142 // erasing this tensor happens because (not the case anymore, it's
143 // avoided) the tensor network might be a disjoint one and a subnetwork
144 // is contracted that does not contain the needed qubit
145
146 tensors.erase(tensor1Id);
147 tensorsMap.erase(tensor1Id);
148 }
149 }
150
151 return std::real(tensors.begin()->second->tensor->atOffset(0));
152 }
153
159 std::shared_ptr<ITensorContractor> Clone() const override {
160 auto cloned = std::make_shared<ForestContractor>();
161
162 cloned->maxTensorRank = maxTensorRank;
163 cloned->enableMultithreading = enableMultithreading;
164
165 return cloned;
166 }
167};
168
169} // namespace TensorNetworks
170
171#endif // __FOREST_CONTRACTOR_H_
The Base Class Tensor Contractor.
bool enableMultithreading
A flag to indicate if multithreading should be enabled.
ITensorContractor::TensorsMap TensorsMap
TensorsMap InitializeTensors(const TensorNetwork &network, Types::qubit_t qubit, std::vector< Eigen::Index > &keys, std::unordered_map< Eigen::Index, Eigen::Index > &keysKeys, bool fillKeys=true, bool contract=true) override
Eigen::Index ContractNodes(Types::qubit_t qubit, PassedTensorsMap &tensors, Eigen::Index tensor1Id, Eigen::Index tensor2Id, Eigen::Index resultRank)
size_t maxTensorRank
The maximum rank of the tensors in the network.
static size_t GetResultRank(const std::shared_ptr< TensorNode > &tensor1, const std::shared_ptr< TensorNode > &tensor2)
The Forest Tensor Contractor.
double Contract(const TensorNetwork &network, Types::qubit_t qubit) override
Contract the tensor network.
std::shared_ptr< ITensorContractor > Clone() const override
Clone the tensor contractor.
static constexpr Index NotConnected
Definition TensorNode.h:155
uint_fast64_t qubit_t
The type of a qubit.
Definition Types.h:20