-
Notifications
You must be signed in to change notification settings - Fork 3
/
spline.py
140 lines (121 loc) · 4.99 KB
/
spline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    Evaluate x on B-spline bases.

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (number of splines, number of samples)
        grid : 2D torch.tensor
            grids, shape (number of splines, number of grid points)
        k : int
            the piecewise polynomial order of splines. Must be >= 0.
        extend : bool
            If True, k points are extended on both ends, spaced by the mean grid step. If False, no extension (zero boundary condition). Default: True
        device : str
            device that x and grid are moved to before any computation.

    Returns:
    --------
        spline values : 3D torch.tensor
            shape (number of splines, number of B-spline bases (coefficients), number of samples).
            With extend=True the number of B-spline bases = number of grid points + k - 1.
            The result is floating point for every k (including k=0).

    Raises:
    -------
        ValueError
            if k is negative.

    Example
    -------
    >>> num_spline = 5
    >>> num_sample = 100
    >>> num_grid_interval = 10
    >>> k = 3
    >>> x = torch.normal(0,1,size=(num_spline, num_sample))
    >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
    >>> B_batch(x, grids, k=k).shape
    torch.Size([5, 13, 100])
    '''
    if k < 0:
        # A negative order would otherwise never reach the k == 0 base case.
        raise ValueError(f"k must be a non-negative integer, got {k}")

    # Move both inputs up front so that the step size, the padded grid and x
    # all live on the same device for every operation below.
    x = x.to(device)
    grid = grid.to(device)

    if extend:
        # Mean spacing of the original grid; used as the step for the padding.
        h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
        for _ in range(k):
            # pad one point to the left and one to the right per iteration
            grid = torch.cat([grid[:, [0]] - h, grid, grid[:, [-1]] + h], dim=1)

    grid = grid.unsqueeze(dim=2)  # (splines, grid points, 1)
    x = x.unsqueeze(dim=1)        # (splines, 1, samples)

    # Order-0 bases are indicators of the half-open intervals [g_i, g_{i+1}).
    # Cast them to a floating dtype so k=0 returns spline values, not a mask.
    dtype = torch.result_type(x, grid)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    value = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(dtype)

    # Raise the order one step at a time; each step blends two neighbouring
    # lower-order bases and therefore yields one basis fewer.
    for p in range(1, k + 1):
        left = (x - grid[:, :-(p + 1)]) / (grid[:, p:-1] - grid[:, :-(p + 1)])
        right = (grid[:, p + 1:] - x) / (grid[:, p + 1:] - grid[:, 1:-p])
        value = left * value[:, :-1] + right * value[:, 1:]
    return value
def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    Convert B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (number of splines, number of samples)
        grid : 2D torch.tensor
            shape (number of splines, number of grid points)
        coef : 2D torch.tensor
            shape (number of splines, number of coef params). number of coef params = number of grid intervals + k
        k : int
            the piecewise polynomial order of splines.
        device : str
            device passed on to B_batch; coef is moved to the device of the basis values.

    Returns:
    --------
        y_eval : 2D torch.tensor
            shape (number of splines, number of samples)

    Example
    -------
    >>> num_spline = 5
    >>> num_sample = 100
    >>> num_grid_interval = 10
    >>> k = 3
    >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample))
    >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
    >>> coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k))
    >>> coef2curve(x_eval, grids, coef, k=k).shape
    torch.Size([5, 100])
    '''
    # x_eval: (size, batch), grid: (size, grid), coef: (size, coef)
    # basis: (size, coef, batch)
    basis = B_batch(x_eval, grid, k, device=device)
    if not basis.is_floating_point():
        # Order-0 bases may come back as a boolean mask, which cannot be
        # contracted against floating point coefficients.
        basis = basis.to(x_eval.dtype)

    # Bring coef to the dtype of x_eval and onto the device holding the basis
    # values, so both einsum operands agree.
    coef = coef.to(device=basis.device, dtype=x_eval.dtype)

    # sum over the coef axis
    y_eval = torch.einsum('ij,ijk->ik', coef, basis)
    return y_eval
def curve2coef(x_eval, y_eval, grid, k, device="cpu"):
    '''
    Convert B-spline curves to B-spline coefficients using least squares.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (number of splines, number of samples)
        y_eval : 2D torch.tensor
            shape (number of splines, number of samples)
        grid : 2D torch.tensor
            shape (number of splines, number of grid points)
        k : int
            the piecewise polynomial order of splines.
        device : str
            device passed on to B_batch; y_eval is moved to the device of the basis values.

    Returns:
    --------
        coef : 2D torch.tensor
            shape (number of splines, number of coef params). number of coef params = number of grid intervals + k

    Example
    -------
    >>> num_spline = 5
    >>> num_sample = 100
    >>> num_grid_interval = 10
    >>> k = 3
    >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample))
    >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample))
    >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
    >>> curve2coef(x_eval, y_eval, grids, k=k).shape
    torch.Size([5, 13])
    '''
    # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar
    # Least-squares design matrix, one per spline: (size, batch, coef)
    mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1)
    if not mat.is_floating_point():
        # Order-0 bases may come back as a boolean mask; lstsq needs floats.
        mat = mat.to(y_eval.dtype)

    # Right-hand side as a column per spline, on the same device as mat.
    target = y_eval.unsqueeze(dim=2).to(mat.device)

    # Pick the driver from where the data actually lives instead of comparing
    # the `device` argument with a string, so torch.device objects work too.
    # 'gelsy' is used on CPU; every other device falls back to 'gels'.
    driver = 'gelsy' if mat.device.type == 'cpu' else 'gels'
    coef = torch.linalg.lstsq(mat, target, driver=driver).solution[:, :, 0]
    return coef