G+Smo  24.08.0
Geometry + Simulation Modules
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
gsSpaceTimeFitter.h
Go to the documentation of this file.
1 
14 #pragma once
15 
16 namespace gismo
17 {
18 
19 template<size_t domainDim, class T>
20 class gsSpaceTimeFitter
21 {
22  typedef typename gsTensorBSpline<domainDim+1,real_t>::BoundaryGeometryType slice_t;
23 
24 
25 public:
26  gsSpaceTimeFitter ( const std::vector<gsMatrix<T>> & solutionCoefs,
27  const gsVector<T> & times,
28  const gsVector<T> & ptimes,
29  const gsMultiBasis<T> & spatialBasis,
30  const index_t deg = 2)
31  :
32  m_data(solutionCoefs),
33  m_times(times),
34  m_ptimes(ptimes),
35  m_bases(spatialBasis),
36  m_deg(deg)
37  {
38  GISMO_ENSURE(m_data.size() != 0,"No data provided!");
39  GISMO_ENSURE(m_data.size() == (size_t)m_times.size(),"Solution coefs and times should have the same size");
40  GISMO_ENSURE((size_t)m_ptimes.size() == (size_t)m_times.size(),"(Parametric)times should have the same size");
41  m_targetDim = m_data.at(0).cols();
42  }
43 
44  void setDegree(index_t deg) {m_deg = deg;}
45 
46  void addDataPoint(gsMatrix<T> & solution, T time, T ptime, index_t continuity )
47  {
48  index_t N = m_ptimes.rows();
49  m_ptimes.conservativeResize(N+1);
50  m_ptimes.row(N) = ptime;
51  gsDebugVar(m_ptimes);
52  std::sort(begin(m_ptimes),end(m_ptimes));
53  gsDebugVar(m_ptimes);
54  }
55 
56 protected:
57 
58  gsTensorBSplineBasis<domainDim+1,T> _basis(const gsKnotVector<T> &kv, index_t nsteps)
59  {
60  return _basis_impl<domainDim+1>(kv,nsteps);
61  }
62 
63  template<index_t _domainDim>
64  typename std::enable_if<_domainDim==2, gsTensorBSplineBasis<_domainDim,T>>::type
65  _basis_impl(const gsKnotVector<T> &kv, index_t nsteps)
66  {
67  gsTensorBSplineBasis<_domainDim,T> tbasis(
68  static_cast<gsBSplineBasis<T> *>(&m_bases.basis(0).component(1))->knots(),
69  kv
70  );
71  return tbasis;
72  }
73 
74  template<index_t _domainDim>
75  typename std::enable_if<_domainDim==3, gsTensorBSplineBasis<_domainDim,T>>::type
76  _basis_impl(const gsKnotVector<T> &kv, index_t nsteps)
77  {
78  gsTensorBSplineBasis<_domainDim,T> tbasis(
79  static_cast<gsBSplineBasis<T> *>(&m_bases.basis(0).component(0))->knots(),
80  static_cast<gsBSplineBasis<T> *>(&m_bases.basis(0).component(1))->knots(),
81  kv
82  );
83  return tbasis;
84  }
85 
86  template<index_t _domainDim>
87  typename std::enable_if<_domainDim==4, gsTensorBSplineBasis<_domainDim,T>>::type
88  _basis_impl(const gsKnotVector<T> &kv, index_t nsteps)
89  {
90  gsTensorBSplineBasis<_domainDim,T> tbasis(
91  static_cast<gsBSplineBasis<T> *>(&m_bases.basis(0).component(0))->knots(),
92  static_cast<gsBSplineBasis<T> *>(&m_bases.basis(0).component(1))->knots(),
93  static_cast<gsBSplineBasis<T> *>(&m_bases.basis(0).component(2))->knots(),
94  kv
95  );
96  return tbasis;
97  }
98 
99 public:
100 
101  std::pair<T,gsGeometry<T> *> slice(T xi)
102  {
103  slice_t res;
104  m_fit.slice(domainDim,xi,res);
105  gsGeometry<T> * geom = res.clone().release();
106  T load = geom->coefs()(0,domainDim+1);
107  geom->embed(m_targetDim);
108  return std::make_pair(load,geom);
109  }
110 
111  void compute()
112  {
113  GISMO_ASSERT(m_data.size()==(size_t)m_times.rows(),"Number of time and solution steps should match! "<<m_data.size()<<"!="<<m_times.rows());
114  // GISMO_ASSERT(m_data.at(0).cols() == dim,"Is the dimension correct?"<<m_data.at(0).cols()<<"!="<<dim);
115  index_t nsteps = m_times.rows();
116  index_t bsize = m_data.at(0).rows();
117 
118  // Prepare fitting basis
119  gsKnotVector<> kv(m_ptimes.minCoeff(),m_ptimes.maxCoeff(),nsteps-(m_deg+1),m_deg+1);
120  gsBSplineBasis<T> lbasis(kv);
121 
126 
127  m_basis = _basis(kv,nsteps);
128 
129  gsMatrix<T> rhs(m_times.size(),m_targetDim*bsize+1);
130  gsVector<T> ones; ones.setOnes(bsize);
131 
132  for (index_t lam = 0; lam!=nsteps; ++lam)
133  {
134  rhs.block(lam,0,1,m_targetDim * bsize) = m_data.at(lam).reshape(1,m_targetDim * bsize);
135  rhs(lam,m_targetDim*bsize) = m_times.at(lam);
136  }
137 
138  // get the Greville Abcissae (anchors)
139  gsMatrix<T> anchors = lbasis.anchors();
140 
141  // Get the collocation matrix at the anchors
142  gsSparseMatrix<T> C = lbasis.collocationMatrix(anchors);
143 
144  gsSparseSolver<>::LU solver;
145  solver.compute(C);
146 
147  gsMatrix<T> sol;
148  m_coefs.resize((nsteps)*bsize,m_targetDim+1);
149 
150  sol = solver.solve(rhs);
151 
152  for (index_t lam = 0; lam!=nsteps; ++lam)
153  {
154  gsMatrix<> tmp = sol.block(lam,0,1,m_targetDim * bsize);
155  m_coefs.block(lam * bsize,0,bsize,m_targetDim) = tmp.reshape(bsize,m_targetDim);
156  m_coefs.block(lam * bsize,m_targetDim,bsize,1) = sol(lam,m_targetDim*bsize) * ones;
157  }
158 
159  // gsTensorBSpline<3,T> tspline = tbasis.makeGeometry(give(coefs)).release();
160  // gsTensorBSpline<dim,T> tspline(m_basis,give(m_coefs));
161 
162  m_fit = gsTensorBSpline<domainDim+1,T>(m_basis,give(m_coefs));
163  }
164 
165 protected:
166  index_t m_targetDim;
167  std::vector<gsMatrix<T>> m_data;
168  gsVector<T> m_times;
169  gsVector<T> m_ptimes;
170  gsMultiBasis<T> m_bases;
171  index_t m_deg;
172 
173  mutable gsTensorBSplineBasis<domainDim+1,T> m_basis;
174  gsMatrix<T> m_coefs;
175 
176  mutable gsTensorBSpline<domainDim+1,T> m_fit;
177 
178 };
179 
180 
181 } // namespace gismo
182 
183 // #ifndef GISMO_BUILD_LIB
184 // #include GISMO_HPP_HEADER(gsSpaceTimeHierarchy.hpp)
185 // #endif
S give(S &x)
Definition: gsMemory.h:266
#define index_t
Definition: gsConfig.h:32
#define GISMO_ENSURE(cond, message)
Definition: gsDebug.h:102
#define GISMO_ASSERT(cond, message)
Definition: gsDebug.h:89