-
Notifications
You must be signed in to change notification settings - Fork 1
/
time_shifter.py
185 lines (132 loc) · 4.72 KB
/
time_shifter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Functions and classes that generate timing shifts for spot patterns.
from random import choice
import numpy as np
def lex_partitions(n):
    """Yield the integer partitions of ``n`` in lexicographic order.

    Adapted from D. Eppstein, "IntegerPartitions.py" (August 2005), which
    generates and manipulates partitions of integers into sums of integers;
    this is the lexicographic counterpart of ``revlex_partitions`` there.

    Each partition is a list of positive integers in non-increasing order,
    e.g. ``lex_partitions(4)`` yields ``[1, 1, 1, 1]``, ``[2, 1, 1]``,
    ``[2, 2]``, ``[3, 1]``, ``[4]``.  ``n == 0`` yields the single empty
    partition ``[]``; negative ``n`` yields nothing.

    Every yielded list is an independent copy, so callers may store or
    mutate it (e.g. ``list(lex_partitions(n))``) without corrupting the
    generator's internal state.
    """
    for p in _lex_partitions_inplace(n):
        yield list(p)


def _lex_partitions_inplace(n):
    """Eppstein's recursive generator; yields ONE list that is mutated in place.

    Consumers must not keep or modify the yielded list across iterations;
    use ``lex_partitions`` to get independent copies.
    """
    if n == 0:
        yield []
    if n <= 0:
        return
    for p in _lex_partitions_inplace(n - 1):
        # Either extend a partition of n-1 with a trailing 1 ...
        p.append(1)
        yield p
        p.pop()
        # ... or grow its last part by one, if the parts stay non-increasing.
        if len(p) == 1 or (len(p) > 1 and p[-1] < p[-2]):
            p[-1] += 1
            yield p
            p[-1] -= 1
class partition_shifter():
    """Randomly shift the onset times of a 6-spot target pattern.

    For an n-spot target pattern (n = 6 here), generate a vector of per-spot
    timing shifts. The draw is probabilistic but uniform over a range of total
    absolute shift.

    The total shift ``t`` is drawn from ``parts_sampling[sample_number]``.
    It is then spread across spots by choosing uniformly among the allowed
    partitions of ``t``. Each spot gets a random sign, and the result is
    scaled by 10. Time units are presumably ms, per the original
    "Resolution 20ms" note — confirm against callers.

    Parameters
    ----------
    sample_number : int
        Key into ``parts_sampling`` selecting the total-shift distribution.
    old_target : sequence of tuple
        Spots as ``(start, end)`` tuples, e.g.
        ``[(10, 90), (50, 130), (90, 170), (130, 210), (170, 250), (210, 290)]``.
        Onsets are assumed to be >= 0; otherwise ``select_part`` may never
        find an acceptable shift.
    """

    # Key: sample_number.  Value: {total shift t: probability}.  Here t is the
    # sum of a partition vector's entries (before the x10 scaling).  Each
    # distribution sums to 1, as np.random.choice requires.
    parts_sampling = {
        1: {2: 0, 4: 0.3, 6: 0.3, 8: 0.3, 10: 0.05, 12: 0, 14: 0, 16: 0.05},
        2: {2: 0, 4: 0, 6: 0, 8: 0, 10: 0.25, 12: 0.25, 14: 0.25, 16: 0.25},
        3: {2: 0, 4: 0.125, 6: 0.125, 8: 0.125, 10: 0.125, 12: 0.125,
            14: 0.125, 16: 0.125, 18: 0.125},
    }

    def __init__(self, sample_number, old_target):
        self.parts = {}
        self.tshifts = []
        self.tweights = []
        self.parts_ratio = self.parts_sampling[sample_number]
        self.gen_partition_doublets()
        self.old_target = old_target
        # Onset (first element) of each spot; shifted onsets must stay >= 0.
        self.target_t = [t[0] for t in old_target]

    def gen_partition_doublets(self):
        """Build ``self.parts``, ``self.tshifts`` and ``self.tweights``.

        ``self.parts[t]`` lists 6-element magnitude vectors whose entries sum
        to ``t``. Each vector has at least two non-zero spots, and each entry
        is at most ``max_single_shift``. The resolution is 20 ms: entries are
        even here and later scaled by 10. Small shifts are oversampled and some
        large shifts are sampled sparsely, via ``parts_ratio``.
        ``tshifts``/``tweights`` keep only the totals with non-zero weight.
        """
        n_spots = 6
        max_single_shift = 4
        shift_max = max(self.parts_ratio)
        parts = {}
        # Floor division: on Python 3 ``shift_max / 2`` is a float and
        # range() rejects it.
        for half in range(2, shift_max // 2 + 1):
            vectors = []
            for x in lex_partitions(half):
                # Lexicographic order means x[0] (the largest part) never
                # decreases, so once it breaks the per-spot cap no later
                # partition can fit.
                if x[0] * 2 > max_single_shift:
                    break
                if 1 < len(x) <= n_spots:
                    padded = x + [0] * (n_spots - len(x))
                    vectors.append([2 * v for v in padded])
            parts[half * 2] = vectors
        # Smallest total shift: one unit on each of two spots.
        parts[2] = [[1, 1, 0, 0, 0, 0]]

        tshifts = []
        tweights = []
        for t in parts:
            weight = self.parts_ratio.get(t, 0)
            if weight == 0:
                continue
            tshifts.append(t)
            tweights.append(weight)
        self.tweights = tweights
        self.tshifts = tshifts
        self.parts = parts

    def select_t(self):
        """Draw one total shift from ``self.tshifts`` weighted by ``self.tweights``."""
        return np.random.choice(self.tshifts, 1, p=self.tweights)[0]

    def select_part(self):
        """Return a signed, shuffled per-spot shift vector (numpy array).

        A total shift ``t`` is drawn once. Candidate vectors are then drawn
        repeatedly until ``target_t[i] + shift[i] >= 0`` for every spot. Each
        candidate is one partition of ``t`` with a random sign of 10 per
        entry, permuted across spots.
        """
        shift = self.select_t()
        candidates = self.parts[shift]
        while True:
            magnitudes = choice(candidates)
            shift_vec = np.random.permutation(
                [choice([-10, 10]) * s for s in magnitudes])
            probe_t = [t + s for t, s in zip(self.target_t, shift_vec)]
            if all(p >= 0 for p in probe_t):
                return shift_vec

    def get_shift_map(self):
        """Map each ``old_target`` tuple to its shifted ``[start, end]`` list.

        Start and end of a spot move by the same amount, so spot durations are
        preserved. The drawn shift vector is printed, as in the original;
        ``print(x)`` outputs the same text on Python 2 and 3.
        """
        shift_vec = self.select_part()
        print(shift_vec)
        return {t_old: list(np.array(t_old) + delta)
                for t_old, delta in zip(self.old_target, shift_vec)}
class duration_shifter():
    """Change the duration of one randomly chosen spot, keeping its onset.

    Parameters
    ----------
    shift_scheme : int
        Key into ``shift_types`` selecting the candidate end-time shifts.
    old_target : sequence of tuple, optional
        ``(start, end)`` spot tuples to shift.  Defaults to the class-level
        ``old_target`` pattern.  No check is made that a shortened spot still
        ends after it starts.
    """

    old_target = [(10, 90), (50, 130), (90, 170), (130, 210), (170, 250), (210, 290)]
    # Scheme 1: extend or reduce the duration by 40 or 60 without changing onset.
    shift_types = {1: [-40, -60, +40, +60]}

    def __init__(self, shift_scheme, old_target=None):
        self.shift_scheme = shift_scheme
        self.shift_list = self.shift_types[shift_scheme]
        if old_target is not None:
            self.old_target = list(old_target)

    def get_shift_map(self):
        """Return ``{spot: spot}`` with one randomly chosen spot's end shifted.

        Unchanged spots map to their own tuple. The chosen spot maps to a new
        ``[start, end + shift]`` list, with ``shift`` drawn from
        ``self.shift_list``.
        """
        shift_map = {t: t for t in self.old_target}
        # choice() needs an indexable sequence; on Python 3 dict.keys() is a
        # view, so materialize the (de-duplicated) keys as a list.
        old_t = choice(list(shift_map))
        shift = choice(self.shift_list)
        shift_map[old_t] = [old_t[0], old_t[1] + shift]
        return shift_map