diff --git a/s4torch/layer.py b/s4torch/layer.py
index e92fa35..cad0888 100644
--- a/s4torch/layer.py
+++ b/s4torch/layer.py
@@ -105,9 +105,9 @@ def __init__(self, d_model: int, n: int, l_max: int) -> None:
         self.l_max = l_max
 
         p, q, lambda_ = map(lambda t: t.type(torch.complex64), _make_p_q_lambda(n))
-        self._p = nn.Parameter(as_real(p))
-        self._q = nn.Parameter(as_real(q))
-        self._lambda_ = nn.Parameter(as_real(lambda_).unsqueeze(0).unsqueeze(1))
+        self.p = nn.Parameter(p)
+        self.q = nn.Parameter(q)
+        self.lambda_ = nn.Parameter(lambda_.unsqueeze(0).unsqueeze(1))
 
         self.register_buffer(
             "omega_l",
@@ -121,38 +121,18 @@ def __init__(self, d_model: int, n: int, l_max: int) -> None:
             ),
         )
 
-        self._B = nn.Parameter(
-            as_real(init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64)))
+        self.B = nn.Parameter(
+            init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64))
         )
-        self._Ct = nn.Parameter(
-            as_real(init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64)))
+        self.Ct = nn.Parameter(
+            init.normal_(torch.empty(d_model, n, dtype=torch.complex64), std=1.0)
         )
-        self.D = nn.Parameter(torch.ones(1, 1, d_model))
+        self.D = nn.Parameter(init.normal_(torch.empty(1, 1, d_model), std=1.0))
         self.log_step = nn.Parameter(_log_step_initializer(torch.rand(d_model)))
 
     def extra_repr(self) -> str:
         return f"d_model={self.d_model}, n={self.n}, l_max={self.l_max}"
 
-    @property
-    def p(self) -> torch.Tensor:
-        return torch.view_as_complex(self._p)
-
-    @property
-    def q(self) -> torch.Tensor:
-        return torch.view_as_complex(self._q)
-
-    @property
-    def lambda_(self) -> torch.Tensor:
-        return torch.view_as_complex(self._lambda_)
-
-    @property
-    def B(self) -> torch.Tensor:
-        return torch.view_as_complex(self._B)
-
-    @property
-    def Ct(self) -> torch.Tensor:
-        return torch.view_as_complex(self._Ct)
-
     def _compute_roots(self) -> torch.Tensor:
         a0, a1 = self.Ct.conj(), self.q.conj()
         b0, b1 = self.B, self.p