-
Notifications
You must be signed in to change notification settings - Fork 1
/
plotting.py
120 lines (99 loc) · 4.81 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
import matplotlib.pyplot as plt
from data_utils import GroundTruth, Landmark, Measurement
from state import State
class Plotter:
    """Draw a particle filter's estimate with matplotlib.

    The behaviour is selected by ``plot_mode``:

    * ``PLOT_REALTIME_PARTICLES`` -- on every :meth:`update`, redraw all
      particles (a dot plus a heading arrow each) and the ground-truth pose.
    * ``PLOT_REALTIME_MEAN`` -- on every :meth:`update`, redraw only the mean
      pose (dot plus heading arrow) and the ground-truth pose.
    * ``PLOT_FINAL_PATH`` -- :meth:`update` does nothing; :meth:`plot` draws
      the estimated and true paths once.

    In the two real-time modes the constructor creates ``fig``, ``ax``,
    ``visible_landmarks``, ``particles_dots``, ``truth_dot`` and
    ``truth_arrow``, plus ``mean_arrow`` (mean mode) or ``particles_arrows``
    (particle mode). None of these exist in ``PLOT_FINAL_PATH`` mode.
    """

    PLOT_REALTIME_PARTICLES, PLOT_REALTIME_MEAN, PLOT_FINAL_PATH = 0, 1, 2

    # Length of the heading arrows drawn for the particles and for the mean.
    # The ground-truth arrow is intentionally left at unit length, as before.
    _HEADING_ARROW_LENGTH = 0.5

    def __init__(self, n_particles, landmarks, plot_mode=PLOT_REALTIME_PARTICLES):
        """Create the figure and the artists that :meth:`update` will move.

        Args:
            n_particles: Number of particles; one heading arrow is created per
                particle in ``PLOT_REALTIME_PARTICLES`` mode.
            landmarks: 2-D array indexed as ``landmarks[:, Landmark.X]`` and
                ``landmarks[:, Landmark.Y]``; only read in real-time modes.
            plot_mode: One of the ``PLOT_*`` class constants.
        """
        self.plot_mode = plot_mode
        if not self._is_realtime():
            # Final-path mode needs no persistent artists; plot() draws
            # everything through pyplot when it is called.
            return

        plt.ion()
        self.fig, self.ax = plt.subplots(figsize=(8, 8))

        # every known landmark, drawn once
        self.ax.plot(landmarks[:, Landmark.X], landmarks[:, Landmark.Y],
                     marker='o', markersize=4, linestyle='None', color='#0047AB')
        # landmarks currently seen by the robot (filled in by update())
        self.visible_landmarks = self._new_dots(markersize=6, color='#FF0500')

        # The creation order of the artists below is kept exactly as in the
        # original code for each mode (truth first in mean mode, truth last in
        # particle mode).
        if plot_mode == self.PLOT_REALTIME_MEAN:
            self._create_truth_artists()
            self.particles_dots = self._new_dots(markersize=2, color='#FF5F15')
            self.mean_arrow = self.ax.arrow(0, 0, 0, 0, color='#222222')
        else:
            self.particles_dots = self._new_dots(markersize=2, color='#FF5F15')
            self.particles_arrows = [self.ax.arrow(0, 0, 0, 0, color='#222222')
                                     for _ in range(n_particles)]
            self._create_truth_artists()

    def _is_realtime(self):
        """Return True when ``plot_mode`` is one of the two real-time modes."""
        return self.plot_mode in (self.PLOT_REALTIME_MEAN,
                                  self.PLOT_REALTIME_PARTICLES)

    def _new_dots(self, markersize, color):
        """Add an empty marker-only line to ``self.ax`` and return it."""
        dots, = self.ax.plot([], [], marker='o', markersize=markersize,
                             linestyle='None', color=color)
        return dots

    def _create_truth_artists(self):
        """Create the ground-truth position dot and heading arrow."""
        self.truth_dot = self._new_dots(markersize=6, color='#228B22')
        self.truth_arrow = self.ax.arrow(0, 0, 0, 0, color='#228B22')

    def update(self, means, state, z, tru, landmarks, i):
        """Redraw the real-time view for step ``i``.

        Does nothing in ``PLOT_FINAL_PATH`` mode.

        Args:
            means: 2-D array; row ``i`` holds the mean pose, read through
                ``State.X``, ``State.Y`` and ``State.HEADING``. Only used in
                ``PLOT_REALTIME_MEAN`` mode.
            state: 2-D array with one particle per row, read through the same
                ``State`` columns. Only used in ``PLOT_REALTIME_PARTICLES``
                mode; it must have at least as many rows as there are arrows.
            z: Current measurements, one per row; the ``Measurement.SUBJECT``
                column is matched against ``Landmark.SUBJECT``. When empty,
                the highlighted landmarks are left as they were.
            tru: Ground-truth rows; only row 0 is drawn. When empty, the
                ground-truth artists are left as they were.
            landmarks: 2-D landmark table (``Landmark.SUBJECT``,
                ``Landmark.X``, ``Landmark.Y`` columns).
            i: Index of the current step into ``means``.
        """
        if not self._is_realtime():
            return

        arrow_len = self._HEADING_ARROW_LENGTH
        if self.plot_mode == self.PLOT_REALTIME_MEAN:
            x = means[i, State.X]
            y = means[i, State.Y]
            theta = means[i, State.HEADING]
            # Line2D.set_data expects sequences; newer matplotlib releases
            # reject bare scalars, so wrap the single point in lists.
            self.particles_dots.set_data([x], [y])
            self.mean_arrow.set_data(x=x, y=y,
                                     dx=np.cos(theta) * arrow_len,
                                     dy=np.sin(theta) * arrow_len)
        else:
            self.particles_dots.set_data(state[:, State.X], state[:, State.Y])
            # A dedicated index name keeps the ``i`` parameter intact.
            for k, arrow in enumerate(self.particles_arrows):
                theta = state[k, State.HEADING]
                arrow.set_data(x=state[k, State.X], y=state[k, State.Y],
                               dx=np.cos(theta) * arrow_len,
                               dy=np.sin(theta) * arrow_len)

        if z.size > 0:
            # Highlight every landmark whose subject id appears in the
            # measurements. Measurements whose subject is not in the landmark
            # table are skipped instead of raising IndexError.
            table = np.asarray(landmarks)
            seen_subjects = np.atleast_2d(np.asarray(z))[:, Measurement.SUBJECT]
            visible = table[np.isin(table[:, Landmark.SUBJECT], seen_subjects)]
            self.visible_landmarks.set_data(visible[:, Landmark.X],
                                            visible[:, Landmark.Y])

        if tru.size > 0:
            x = tru[0, GroundTruth.X]
            y = tru[0, GroundTruth.Y]
            heading = tru[0, GroundTruth.H]
            self.truth_dot.set_data([x], [y])
            self.truth_arrow.set_data(x=x, y=y,
                                      dx=np.cos(heading), dy=np.sin(heading))

        # Pin the view on the plotter's own axes rather than on whatever
        # pyplot currently considers the active axes.
        self.ax.set_xlim(-6, 9)
        self.ax.set_ylim(-7.5, 7.5)
        # Let the GUI event loop run so the figure actually refreshes.
        plt.pause(1e-3)

    def plot(self, means, ground_truth):
        """Draw the estimated path and the true path in final-path mode.

        Does nothing unless ``plot_mode`` is ``PLOT_FINAL_PATH``.

        Args:
            means: 2-D array of mean poses, read through ``State.X`` and
                ``State.Y``; drawn in blue.
            ground_truth: 2-D array of true poses, read through
                ``GroundTruth.X`` and ``GroundTruth.Y``; drawn in red.
        """
        if self.plot_mode != self.PLOT_FINAL_PATH:
            return

        plt.plot(means[:, State.X], means[:, State.Y], 'b')
        # added to the same (current) figure
        plt.plot(ground_truth[:, GroundTruth.X],
                 ground_truth[:, GroundTruth.Y], 'r')
        plt.xlim((-10, 10))
        plt.ylim((-10, 10))
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.legend(['model', 'truth'])
        plt.show()