Skip to content
This repository has been archived by the owner on Mar 26, 2022. It is now read-only.

Commit

Permalink
Fix type bug
Browse files Browse the repository at this point in the history
  • Loading branch information
BreezeWhite committed Oct 4, 2021
1 parent f735562 commit 64569f9
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

End-to-end Optical Music Recognition system built on deep learning models and machine learning techniques.
Defaults to **Onnxruntime** for model inference. If you want to use **tensorflow** for the inference,
run `export INFERENCE_WITH_TF=True` and make sure there is TF installed.
run `export INFERENCE_WITH_TF=true` and make sure there is TF installed.

![](figures/tabi_mix.jpg)

Expand Down
4 changes: 2 additions & 2 deletions oemer/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def get_bbox(data):


def get_center(bbox):
cen_y = round((bbox[1] + bbox[3]) / 2)
cen_x = round((bbox[0] + bbox[2]) / 2)
cen_y = int(round((bbox[1] + bbox[3]) / 2))
cen_x = int(round((bbox[0] + bbox[2]) / 2))
return cen_x, cen_y


Expand Down
3 changes: 2 additions & 1 deletion oemer/build_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ def get_rest(duration):

def get_chroma_pitch(pos, clef_type):
order = G_CLEF_POS_TO_PITCH if clef_type == ClefType.G_CLEF else F_CLEF_POS_TO_PITCH
pos = int(pos)
return order[pos%7] if pos >= 0 else order[pos%-7]


Expand Down Expand Up @@ -895,7 +896,7 @@ def decode_note(note, clef_type, is_chord=False, voice=1) -> Element:
alter = SubElement(pitch, 'alter')
octave = SubElement(pitch, 'octave')
alter.text = '0'
pos = note.staff_line_pos
pos = int(note.staff_line_pos)
if clef_type == ClefType.G_CLEF:
order = G_CLEF_POS_TO_PITCH
oct_offset = 4
Expand Down
8 changes: 5 additions & 3 deletions oemer/dewarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def connect_nearby_grid_group(gg_map, grid_groups, grid_map, grids, ref_count=8,
for i in range(max_step):
tar_x = (-i - 1) * step_size
cen_y = model.predict([[tar_x]])[0] # Interpolate y center
y = round(cen_y - h / 2)
y = int(round(cen_y - h / 2))
region = new_gg_map[y:y+h, end_x-step_size:end_x] # Area to check
unique, counts = np.unique(region, return_counts=True)
labels = set(unique) # Overlapped grid group IDs
Expand Down Expand Up @@ -202,6 +202,8 @@ def connect_nearby_grid_group(gg_map, grid_groups, grid_map, grids, ref_count=8,
max(gg.bbox[2], box[2]),
max(gg.bbox[3], box[3])
)
gg.bbox = [int(bb) for bb in gg.bbox]
box = [int(bb) for bb in box]
grids.append(grid)
new_gg_map[box[1]:box[3], box[0]:box[2]] = gg.id

Expand Down Expand Up @@ -236,7 +238,7 @@ def build_mapping(gg_map, min_width_ratio=0.4):
meta_idx = np.where(x==ux)[0]
sub_y = y[meta_idx]
cen_y = round(np.mean(sub_y))
coords_y[target_y, ux] = cen_y
coords_y[int(target_y), int(ux)] = cen_y
points.append((target_y, ux))

# Add corner case
Expand Down Expand Up @@ -267,7 +269,7 @@ def estimate_coords(staff_pred):
coords_y, points = build_mapping(new_gg_map)

logger.debug("Dewarping")
vals = coords_y[points[:, 0], points[:, 1]]
vals = coords_y[points[:, 0].astype(int), points[:, 1].astype(int)]
grid_x, grid_y = np.mgrid[0:gg_map.shape[0]:1, 0:gg_map.shape[1]:1]
coords_y = griddata(points, vals, (grid_x, grid_y), method='linear')

Expand Down
2 changes: 1 addition & 1 deletion oemer/note_group_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def search(cur_y, y_bound, step):
break
elif step < 0 and cur_y < y_bound:
break
pxs = group_map[cur_y, start_x:end_x]
pxs = group_map[int(cur_y), int(start_x):int(end_x)]
gids = set(np.unique(pxs))
if 0 in gids:
gids.remove(0)
Expand Down
8 changes: 4 additions & 4 deletions oemer/notehead_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ def __repr__(self):


def morph_notehead(pred, unit_size):
small_size = round(unit_size / 3)
small_size = int(round(unit_size / 3))
small_ker = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (small_size, small_size))
pred = cv2.erode(cv2.dilate(pred.astype(np.uint8), small_ker), small_ker)
size = (
round(unit_size*nhc.NOTEHEAD_MORPH_WIDTH_FACTOR),
round(unit_size*nhc.NOTEHEAD_MORPH_HEIGHT_FACTOR)
int(round(unit_size*nhc.NOTEHEAD_MORPH_WIDTH_FACTOR)),
int(round(unit_size*nhc.NOTEHEAD_MORPH_HEIGHT_FACTOR))
)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, size)
img = cv2.erode(pred.astype(np.uint8), kernel)
Expand Down Expand Up @@ -146,7 +146,7 @@ def check_bbox_size(bbox, noteheads, unit_size):
tmp_new.extend(check_bbox_size(box, noteheads, unit_size))
new_bbox = tmp_new
else:
num_notes = round(h / note_h)
num_notes = int(round(h / note_h))
if num_notes > 0:
sub_h = h // num_notes
for i in range(num_notes):
Expand Down
15 changes: 10 additions & 5 deletions oemer/rhythm_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def scan_dot(symbols, note_id_map, bbox, unit_size, min_count, max_count):
# Find the rightmost bound for scanning the dot.
# The scan region should be narrower than unit_size and must not
# touch the nearby note.
cur_scan_line = note_id_map[start_y:bbox[3], right_bound]
cur_scan_line = note_id_map[int(start_y):int(bbox[3]), int(right_bound)]
ids = set(np.unique(cur_scan_line))
if -1 in ids:
ids.remove(-1)
Expand All @@ -34,13 +34,13 @@ def scan_dot(symbols, note_id_map, bbox, unit_size, min_count, max_count):
break

left_bound = bbox[2] + round(unit_size*0.4)
dot_region = symbols[start_y:bbox[3], left_bound:right_bound]
dot_region = symbols[int(start_y):int(bbox[3]), int(left_bound):int(right_bound)]
pixels = np.sum(dot_region)
if min_count <= pixels <= max_count:
color = (255, random.randint(0, 255), random.randint(0, 255))
cv2.rectangle(dot_img, (left_bound, start_y), (right_bound, bbox[3]), color, 1)
cv2.rectangle(dot_img, (int(left_bound), int(start_y)), (int(right_bound), int(bbox[3])), color, 1)
msg = f"{min_count:.2f}/{pixels:.2f}/{max_count:.2f}"
cv2.putText(dot_img, msg, (bbox[0], bbox[3]+30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 1)
cv2.putText(dot_img, msg, (int(bbox[0]), int(bbox[3])+30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 1)
return True

# color = (255, random.randint(0, 255), random.randint(0, 255))
Expand Down Expand Up @@ -326,6 +326,11 @@ def scan_beam_flag(
min_width_ratio=0.25,
max_width_ratio=0.9):

start_x = int(start_x)
start_y = int(start_y)
end_x = int(end_x)
end_y = int(end_y)

cv2.line(beam_img, (start_x, start_y), (end_x, start_y), (66, 245, 212), 1, cv2.LINE_8)
cv2.line(beam_img, (start_x, end_y), (end_x, end_y), (66, 245, 212), 1, cv2.LINE_8)

Expand Down Expand Up @@ -558,7 +563,7 @@ def parse_rhythm(beam_map, map_info, agree_th=0.15):
)

#cv2.rectangle(beam_img, (gbox[0], gbox[1]), (gbox[2], gbox[3]), (255, 0, 255), 1)
cv2.putText(beam_img, str(count), (cen_x, gbox[3]+2), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 1)
cv2.putText(beam_img, str(count), (int(cen_x), int(gbox[3])+2), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 1)

# Assign note label
for nid in group.note_ids:
Expand Down
4 changes: 2 additions & 2 deletions oemer/symbol_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def gen_clefs(bboxes, labels):

def get_nearby_note_id(box, note_id_map):
cen_x, cen_y = get_center(box)
unit_size = round(get_unit_size(cen_x, cen_y))
unit_size = int(round(get_unit_size(cen_x, cen_y)))
nid = None
for x in range(box[2], box[2]+unit_size):
if note_id_map[cen_y, x] != -1:
Expand Down Expand Up @@ -413,7 +413,7 @@ def gen_rests(bboxes, labels):
rr.track = st1.track
rr.group = st1.group

unit_size = round(get_unit_size(*get_center(box)))
unit_size = int(round(get_unit_size(*get_center(box))))
dot_range = range(box[2]+1, box[2]+unit_size)
dot_region = symbols[box[1]:box[3], dot_range]
if 0 < np.sum(dot_region) < unit_size**2 / 7:
Expand Down

0 comments on commit 64569f9

Please sign in to comment.