/// gsSpaceTimeFitter: fits a tensor B-spline with one more parametric
/// direction than the spatial basis (gsTensorBSpline<domainDim+1,T>) through
/// a sequence of spatial solution snapshots.
///
/// @tparam domainDim  parametric dimension of the spatial basis
/// @tparam T          scalar (coefficient) type
///
/// NOTE(review): this listing appears to come from a rendered source view:
/// the leading integers (19, 20, 22, ...) are original line numbers, and many
/// lines (braces, access specifiers, some initializers and statements) are
/// missing from it.
19 template<
size_t domainDim,
class T>
20 class gsSpaceTimeFitter
// Boundary-geometry type of the (domainDim+1)-variate spline; going by its
// name it is the type of a lower-dimensional slice.
// NOTE(review): instantiated with real_t rather than the class parameter T;
// the two differ whenever T != real_t — confirm whether T was intended.
22 typedef typename gsTensorBSpline<domainDim+1,real_t>::BoundaryGeometryType slice_t;
/// Construct the fitter from a sequence of solution snapshots.
/// @param solutionCoefs  one coefficient matrix per snapshot; the column count
///                       of the first matrix becomes the target dimension
/// @param times          time (or load) value of each snapshot
/// @param ptimes         parametric coordinate of each snapshot in the fitted
///                       direction
/// @param spatialBasis   spatial multi-basis the snapshot coefficients refer to
/// NOTE(review): further parameters and the m_times / m_ptimes initializers
/// are not visible in this listing; the checks read m_times and m_ptimes.
26 gsSpaceTimeFitter (
const std::vector<gsMatrix<T>> & solutionCoefs,
27 const gsVector<T> & times,
28 const gsVector<T> & ptimes,
29 const gsMultiBasis<T> & spatialBasis,
32 m_data(solutionCoefs),
35 m_bases(spatialBasis),
// Require exactly one time and one parametric time per snapshot.
39 GISMO_ENSURE(m_data.size() == (size_t)m_times.size(),
"Solution coefs and times should have the same size");
40 GISMO_ENSURE((
size_t)m_ptimes.size() == (size_t)m_times.size(),
"(Parametric)times should have the same size");
// Target dimension = number of columns of the first snapshot.
// NOTE(review): if no snapshots are given (and times/ptimes are empty too, so
// both checks pass), .at(0) throws std::out_of_range here.
41 m_targetDim = m_data.at(0).cols();
/// Set the degree m_deg used to build the knot vector of the fitted direction.
/// @param deg  new degree
44 void setDegree(
index_t deg) {m_deg = deg;}
/// Append one snapshot to the data set.
/// @param solution    coefficient matrix of the new snapshot
/// @param time        time (or load) value of the snapshot
/// @param ptime       parametric coordinate of the snapshot
/// @param continuity  NOTE(review): not used in the visible lines
/// NOTE(review): several lines of this function are not visible (the
/// definition of N and any updates of m_data / m_times).
46 void addDataPoint(gsMatrix<T> & solution, T time, T ptime,
index_t continuity )
// Grow m_ptimes to N+1 entries and store ptime at index N
// (presumably N is the previous size).
49 m_ptimes.conservativeResize(N+1);
50 m_ptimes.row(N) = ptime;
// Keep the parametric times in ascending order.
// NOTE(review): only m_ptimes is sorted here. Unless m_data and m_times are
// reordered the same way elsewhere, a ptime smaller than an existing one gets
// detached from its snapshot.
52 std::sort(begin(m_ptimes),end(m_ptimes));
/// Build the (domainDim+1)-variate tensor B-spline basis for the fit by
/// dispatching to the _basis_impl overload enabled for dimension domainDim+1.
/// @param kv      knot vector of the fitted direction
/// @param nsteps  number of snapshots
58 gsTensorBSplineBasis<domainDim+1,T> _basis(
const gsKnotVector<T> &kv,
index_t nsteps)
60 return _basis_impl<domainDim+1>(kv,nsteps);
/// 2-D space-time case (single spatial direction): builds the tensor basis
/// from the spatial knot vector of patch 0 of m_bases. The basis is downcast
/// with an unchecked static_cast to gsBSplineBasis<T>.
/// NOTE(review): this overload reads component(1), while the 3-D and 4-D
/// overloads start at component(0). For a single spatial direction,
/// component(0) looks intended — confirm against gsBSplineBasis::component.
/// NOTE(review): the remaining constructor arguments (presumably kv) and the
/// return statement are not visible in this listing.
63 template<index_t _domainDim>
64 typename std::enable_if<_domainDim==2, gsTensorBSplineBasis<_domainDim,T>>::type
65 _basis_impl(
const gsKnotVector<T> &kv,
index_t nsteps)
67 gsTensorBSplineBasis<_domainDim,T> tbasis(
68 static_cast<gsBSplineBasis<T> *
>(&m_bases.basis(0).component(1))->knots(),
/// 3-D space-time case (two spatial directions): builds the tensor basis from
/// the knot vectors of components 0 and 1 of patch 0 of m_bases. Each
/// component is downcast with an unchecked static_cast to gsBSplineBasis<T>.
/// NOTE(review): the remaining constructor arguments (presumably kv) and the
/// return statement are not visible in this listing.
74 template<index_t _domainDim>
75 typename std::enable_if<_domainDim==3, gsTensorBSplineBasis<_domainDim,T>>::type
76 _basis_impl(
const gsKnotVector<T> &kv,
index_t nsteps)
78 gsTensorBSplineBasis<_domainDim,T> tbasis(
79 static_cast<gsBSplineBasis<T> *
>(&m_bases.basis(0).component(0))->knots(),
80 static_cast<gsBSplineBasis<T> *
>(&m_bases.basis(0).component(1))->knots(),
/// 4-D space-time case (three spatial directions): builds the tensor basis
/// from the knot vectors of components 0, 1 and 2 of patch 0 of m_bases. Each
/// component is downcast with an unchecked static_cast to gsBSplineBasis<T>.
/// NOTE(review): the remaining constructor arguments (presumably kv) and the
/// return statement are not visible in this listing.
86 template<index_t _domainDim>
87 typename std::enable_if<_domainDim==4, gsTensorBSplineBasis<_domainDim,T>>::type
88 _basis_impl(
const gsKnotVector<T> &kv,
index_t nsteps)
90 gsTensorBSplineBasis<_domainDim,T> tbasis(
91 static_cast<gsBSplineBasis<T> *
>(&m_bases.basis(0).component(0))->knots(),
92 static_cast<gsBSplineBasis<T> *
>(&m_bases.basis(0).component(1))->knots(),
93 static_cast<gsBSplineBasis<T> *
>(&m_bases.basis(0).component(2))->knots(),
/// Extract the spatial geometry at a fixed parameter in the fitted direction.
/// @param xi  parameter value in direction domainDim (the last one)
/// @return (load, geometry): the value of the appended load/time column, and
///         the spatial geometry at xi. The geometry comes from
///         clone().release(), so the caller owns it and must delete it.
/// NOTE(review): consider returning an owning smart pointer instead.
101 std::pair<T,gsGeometry<T> *> slice(T xi)
104 m_fit.slice(domainDim,xi,res);
105 gsGeometry<T> * geom = res.clone().release();
// The fit writes one value per time layer into the appended column for every
// spatial control point (`sol(lam,...) * ones`), so row 0 is representative.
// NOTE(review): the fit stores that column at index m_targetDim, but it is
// read here at index domainDim+1. The two agree only when
// m_targetDim == domainDim+1 (e.g. a surface in 3-D); otherwise this reads
// the wrong column or goes out of range.
106 T load = geom->coefs()(0,domainDim+1);
// Presumably reduces the coefficients to m_targetDim columns, dropping the
// load column — confirm against gsGeometry::embed.
107 geom->embed(m_targetDim);
108 return std::make_pair(load,geom);
/// Fit m_fit to the stored snapshots. The function signature is not visible
/// in this listing. The steps are:
/// 1. Flatten each snapshot into one row of a right-hand side, together with
///    its m_times value.
/// 2. Solve the rows against a collocation matrix of a 1-D B-spline basis over
///    [min ptime, max ptime]. The solver is presumably factorized from C; the
///    compute call is not visible.
/// 3. Reshape the solution into the control points of the (domainDim+1)-variate
///    tensor B-spline.
113 GISMO_ASSERT(m_data.size()==(size_t)m_times.rows(),
"Number of time and solution steps should match! "<<m_data.size()<<
"!="<<m_times.rows());
// nsteps snapshots, each with bsize spatial coefficients (rows).
115 index_t nsteps = m_times.rows();
116 index_t bsize = m_data.at(0).rows();
// Fitted-direction knot vector over [min ptime, max ptime], with end
// multiplicity m_deg+1 and nsteps-(m_deg+1) interior knots. This presumably
// gives exactly nsteps basis functions of degree m_deg.
// NOTE(review): needs nsteps >= m_deg+1. gsKnotVector<> uses the library's
// default scalar type rather than T; the same applies to gsSparseSolver<> and
// gsMatrix<> further down.
119 gsKnotVector<> kv(m_ptimes.minCoeff(),m_ptimes.maxCoeff(),nsteps-(m_deg+1),m_deg+1);
120 gsBSplineBasis<T> lbasis(kv);
127 m_basis = _basis(kv,nsteps);
// Right-hand side: one row per snapshot, holding m_targetDim*bsize flattened
// coefficients followed by that snapshot's time value.
129 gsMatrix<T> rhs(m_times.size(),m_targetDim*bsize+1);
130 gsVector<T> ones; ones.setOnes(bsize);
132 for (
index_t lam = 0; lam!=nsteps; ++lam)
134 rhs.block(lam,0,1,m_targetDim * bsize) = m_data.at(lam).reshape(1,m_targetDim * bsize);
135 rhs(lam,m_targetDim*bsize) = m_times.at(lam);
// NOTE(review): the collocation points are the anchors of lbasis, not
// m_ptimes. The intervening lines are not visible — confirm that the data's
// parametric times are actually what gets interpolated.
139 gsMatrix<T> anchors = lbasis.anchors();
142 gsSparseMatrix<T> C = lbasis.collocationMatrix(anchors);
144 gsSparseSolver<>::LU solver;
// Control points: nsteps layers of bsize rows. Columns 0..m_targetDim-1 hold
// the solution; column m_targetDim holds the fitted time value.
148 m_coefs.resize((nsteps)*bsize,m_targetDim+1);
150 sol = solver.solve(rhs);
152 for (
index_t lam = 0; lam!=nsteps; ++lam)
154 gsMatrix<> tmp = sol.block(lam,0,1,m_targetDim * bsize);
155 m_coefs.block(lam * bsize,0,bsize,m_targetDim) = tmp.reshape(bsize,m_targetDim);
// Replicate the layer's time value for every spatial control point.
156 m_coefs.block(lam * bsize,m_targetDim,bsize,1) = sol(lam,m_targetDim*bsize) * ones;
// m_coefs is handed to the new geometry through give().
162 m_fit = gsTensorBSpline<domainDim+1,T>(m_basis,
give(m_coefs));
// Snapshot coefficient matrices; the fit treats each one as bsize rows by
// m_targetDim columns.
167 std::vector<gsMatrix<T>> m_data;
// Parametric coordinate of each snapshot in the fitted direction.
169 gsVector<T> m_ptimes;
// Spatial multi-basis; only patch 0 is used when building the tensor basis.
170 gsMultiBasis<T> m_bases;
// Space-time tensor basis assigned by the fit (declared mutable).
173 mutable gsTensorBSplineBasis<domainDim+1,T> m_basis;
// Fitted space-time geometry; slice() extracts spatial geometries from it.
176 mutable gsTensorBSpline<domainDim+1,T> m_fit;
S give(S &x)
Definition: gsMemory.h:266
#define index_t
Definition: gsConfig.h:32
#define GISMO_ENSURE(cond, message)
Definition: gsDebug.h:102
#define GISMO_ASSERT(cond, message)
Definition: gsDebug.h:89