Maestro 0.2.11
Unified interface for quantum circuit simulation
Loading...
Searching...
No Matches
MultivariateHermiteInterpolation.h
Go to the documentation of this file.
1
18
19#pragma once
20
21#ifndef __UTILS_MULTIVARIATE_HERMITE_INTERPOLATION_H__
22#define __UTILS_MULTIVARIATE_HERMITE_INTERPOLATION_H__
23
24#include <cassert>
25#include <memory>
26#include <vector>
27
29
30namespace Utils {
31
33 public:
// Samples are assumed to be sorted lexicographically by the coordinates in x.
// All x[i] must have the same dimension, which must be at least 2.
36 void SetSamples(const std::vector<std::vector<double>>& x, const std::vector<double>& y)
37 {
38 assert(x.size() == y.size());
39 assert(!x.empty());
40
41 if (x.empty() || y.empty()) return;
42
43 assert(x[0].size() >= 2);
44
45 leafInterpolator.reset();
46 children.clear();
47 xValues.clear();
48
49 dimension = x[0].size();
50
51 if (dimension == 2)
52 {
53 // degenerate to the bivariate case
54 std::vector<std::vector<double>> xv;
55 xv.reserve(x.size());
56 for (size_t i = 0; i < x.size(); ++i)
57 xv.push_back(x[i]);
58
59 leafInterpolator = std::make_unique<BivariateHermiteInterpolation>();
60 leafInterpolator->SetTrueInterpolation(trueInterpolation);
61 leafInterpolator->SetSamples(xv, y);
62 return;
63 }
64
65 // Group samples by the first coordinate, then recurse on the rest.
66 std::vector<std::vector<double>> subX;
67 std::vector<double> subY;
68
69 auto flushGroup = [&]() {
70 if (subX.empty())
71 return;
72 children.emplace_back(std::make_unique<MultivariateHermiteInterpolation>());
73 children.back()->SetTrueInterpolation(trueInterpolation);
74 children.back()->SetSamples(subX, subY);
75 subX.clear();
76 subY.clear();
77 };
78
79 xValues.push_back(x[0][0]);
80 for (size_t i = 0; i < x.size(); ++i)
81 {
82 assert(x[i].size() == dimension);
83
84 if (x[i][0] != xValues.back())
85 {
86 flushGroup();
87 xValues.push_back(x[i][0]);
88 }
89
90 subX.emplace_back(x[i].begin() + 1, x[i].end());
91 subY.push_back(y[i]);
92 }
93 flushGroup();
94 }
95
96 double Predict(const std::vector<double>& x) const
97 {
98 assert(x.size() == dimension);
99
100 if (dimension < 2 || x.size() != dimension)
101 return 0;
102
103 if (leafInterpolator)
104 return leafInterpolator->Predict(x);
105
106 const std::vector<double> tail(x.begin() + 1, x.end());
107
108 std::vector<double> vals;
109 vals.reserve(children.size());
110 for (size_t i = 0; i < children.size(); ++i)
111 vals.push_back(children[i]->Predict(tail));
112
113 HermiteInterpolation firstInterpolator;
114 firstInterpolator.SetTrueInterpolation(trueInterpolation);
115 firstInterpolator.SetSamples(xValues, vals);
116
117 const double val = firstInterpolator.Predict(x[0]);
118
119 if (trueInterpolation)
120 return val;
121
122 const auto m = 1E-12;
123 if (val < m)
124 return m;
125
126 return val;
127 }
128
129 void SetTrueInterpolation(bool reg)
130 {
131 trueInterpolation = reg;
132
133 if (leafInterpolator)
134 leafInterpolator->SetTrueInterpolation(reg);
135
136 for (auto& child : children)
137 if (child)
138 child->SetTrueInterpolation(reg);
139 }
140
141 private:
142 // Used when dimension == 2.
143 std::unique_ptr<BivariateHermiteInterpolation> leafInterpolator;
144
145 // Used when dimension > 2: one child per unique value of the first coordinate.
146 std::vector<std::unique_ptr<MultivariateHermiteInterpolation>> children;
147 std::vector<double> xValues;
148
149 size_t dimension = 0;
150
151 bool trueInterpolation = false;
152 };
153
154}
155
156#endif // __UTILS_MULTIVARIATE_HERMITE_INTERPOLATION_H__
double Predict(double x) const
void SetSamples(const std::vector< T > &x, const std::vector< double > &y)
void SetSamples(const std::vector< std::vector< double > > &x, const std::vector< double > &y)
double Predict(const std::vector< double > &x) const
Definition Alias.h:22