From 18b6f72e55835830460b1f7229af645acb324770 Mon Sep 17 00:00:00 2001 From: Dan Fu Date: Wed, 15 Jun 2022 21:13:27 -0700 Subject: [PATCH 1/6] S4 simple --- configs/experiment/s4-simple-cifar.yaml | 34 ++++ configs/model/layer/s4_simple.yaml | 32 +++ src/models/sequence/ss/s4_simple.py | 254 ++++++++++++++++++++++++ src/utils/registry.py | 3 +- 4 files changed, 322 insertions(+), 1 deletion(-) create mode 100644 configs/experiment/s4-simple-cifar.yaml create mode 100644 configs/model/layer/s4_simple.yaml create mode 100644 src/models/sequence/ss/s4_simple.py diff --git a/configs/experiment/s4-simple-cifar.yaml b/configs/experiment/s4-simple-cifar.yaml new file mode 100644 index 00000000..453598ef --- /dev/null +++ b/configs/experiment/s4-simple-cifar.yaml @@ -0,0 +1,34 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: s4 + - override /model/layer: s4_simple + - override /scheduler: cosine_warmup + +model: + dropout: 0.1 + n_layers: 4 + d_model: 128 + prenorm: false + layer: + #scaling: linear + d_state: 64 + lr: 0.001 + postact: glu + #bidirectional: false + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + weight_decay: 0.01 + +trainer: + max_epochs: 100 + +scheduler: + num_training_steps: 100000 + +train: + seed: 1111 diff --git a/configs/model/layer/s4_simple.yaml b/configs/model/layer/s4_simple.yaml new file mode 100644 index 00000000..e436982f --- /dev/null +++ b/configs/model/layer/s4_simple.yaml @@ -0,0 +1,32 @@ +_name_: s4_simple +d_state: 64 +channels: 1 +bidirectional: false +activation: gelu +postact: null +initializer: null +weight_norm: false +#hyper_act: null +dropout: ${..dropout} # Same as null +#measure: legs +#rank: 1 +dt_min: 0.001 +dt_max: 0.1 +# trainable: +# dt: true +# A: true +# P: true +# B: true +lr: 0.001 +# mode: nplr +# n_ssm: 1 +# resample: false +#deterministic: false # Special C init +#l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to 1 and kernel will automatically resize +#verbose: true +use_initial: true +learn_a: true +learn_theta: true +trap_rule: false +zero_order_hold: true +theta_scale: false \ No newline at end of file diff --git a/src/models/sequence/ss/s4_simple.py b/src/models/sequence/ss/s4_simple.py new file mode 100644 index 00000000..c7b2bae1 --- /dev/null +++ b/src/models/sequence/ss/s4_simple.py @@ -0,0 +1,254 @@ +import torch +import torch.nn as nn +from einops import rearrange, repeat +import opt_einsum as oe +from src.models.nn import Activation, LinearActivation + +import math + +# Replacement for Dropout in PyTorch 1.11.0 +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True): + """ tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + For some reason tie=False is dog slow, prob something wrong with torch.distribution + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """ X: (batch, dim, lengths...) 
""" + if self.training: + # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape + # mask = self.binomial.sample(mask_shape) + mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p + return X * mask * (1.0/(1-self.p)) + return X + +class OurModule(nn.Module): + def __init__(self): super().__init__() + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: optim["lr"] = lr + if trainable and wd is not None: optim["weight_decay"] = wd + if len(optim) > 0: setattr(getattr(self, name), "_optim", optim) + +# +# This is intended to match np.convolve(x,w)[:len(w)] +# That is, (u \ast v)[k] = sum_{j} u[k-j]v[j] +# Here y = (u \ask v) on return. +# We assume the inputs are: +# u (B H L) +# v (C H L) +# and we want to produce y that is (B C H L) +# +def fft_conv(u,v): + L = u.shape[-1] + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + v_f = torch.fft.rfft(v, n=2*L) # (C H L) + + y_f = oe.contract('bhl,chl->bchl', u_f, v_f) + y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) + return y + +class SimpleS4(OurModule): + def __init__(self, + nHippos, + d_state=64, + channels=1, + use_initial=True, # Use the initial state? + zero_order_hold=True, # Use zero-order hold approximation + trap_rule=False, + dt_min=0.001, + dt_max=0.1, + lr=None, # Hook to set LR of SSM parameters differently + learn_a=True, + learn_theta=True, + theta_scale=False, + **kernel_args,): # Use the trapezoid rule + super().__init__() + # H is number of hippos + # D is the dimension (also shockingly n other places) + # B is the batch + # L is the length + self.h = nHippos + self.d = d_state // 2 + self.channels = channels + self.use_initial = use_initial + self.zero_order_hold = zero_order_hold + # + # Use the trapezoid rule correct or just do zero-order hold. + self.trap_rule = trap_rule + + _fp = (self.channels, self.h, self.d) + + # Chebyshev initialization + h_scale = torch.exp(torch.arange(self.h)/self.h * math.log(dt_max/dt_min)) + angles = torch.arange(self.d)*torch.pi + theta_scale = h_scale if theta_scale else torch.ones(self.h) + theta = oe.contract('c,h,d->chd', torch.ones(self.channels), h_scale, angles) + a = -repeat(h_scale, 'h -> c h d', c=self.channels, d=self.d) + #a = -oe.contract('c,h,d->chd', torch.ones(self.channels), _log_T, + # torch.ones(self.d)) + + self.register("theta", theta,learn_theta,lr=lr, wd=None) + self.register("a", a, learn_a,lr=lr, wd=None) + # The other maps + self.D = nn.Parameter(torch.randn(channels, self.h)) + + if use_initial: + self.b = nn.Parameter(torch.randn(*_fp)) + self.c = nn.Parameter(torch.randn(*_fp)) + self.x0 = nn.Parameter(torch.randn(channels, self.h, self.d)) + else: + # This is an optimization that we combine q = c * b + # It's as if we're setting x0 = 0. 
+ self.q = nn.Parameter(torch.randn(*_fp)) + + def zoh_method(self, u): + l = u.size(-1) + T = 1/(l-1) + zk = T*torch.arange(u.size(-1), device=u.device).view(1,1,-1,1) + ls = torch.complex(-self.a.abs(), self.theta) + term_0 = (torch.exp(ls*T) - 1)/ls + base_term = (2*term_0.unsqueeze(2)*torch.exp(ls.unsqueeze(2)* zk)).real + q = self.b*self.c if self.use_initial else self.q + f = (q.unsqueeze(2)*base_term).sum(-1) + y = fft_conv(u,f) + y = y + oe.contract('bhl,ch->bchl', u, self.D) + if self.use_initial: + # This the cosine formula from the note + cos_term = 2*T*torch.exp(-self.a.abs().unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) + y = y + (2*(self.c*self.x0).unsqueeze(2)*cos_term).sum(-1) + return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels. + + def quadrature_method(self, u): + # The input is now Batch x Hippos x Length + l = u.size(-1) + T = 1/(l-1) # the step size + zk = T*torch.arange(l, device=u.device).view(1,1,-1,1) + #T = torch.exp(self.log_T).to(u.device).unsqueeze(-1).unsqueeze(-1) + #zk = T*torch.arange(l, device=u.device).view(1,1,-1,1) + # q and a are both C x H x D + # zk is of length l we want a C x H x L matrix + # From the note, we have + # f[k] = 2 sum_{j=1}^{d} q_j e^{a_j z_k} cos( z_k * theta_j ) + # we construct the body of the sum + base_term = 2*T*torch.exp(-self.a.abs().unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) + q = self.b*self.c if self.use_initial else self.q + f = (q.unsqueeze(2)*base_term).sum(-1) + + # after summing f it is now an C H L matrix + g = u # this is a B H L matrix + # we want to convolve on L and produce a B H C L + # + y = fft_conv(g,f) + if self.trap_rule: + y = y - T*(oe.contract('ch,bhl -> bchl', f[:,:,0], g) + oe.contract('chl,bh -> bchl', f, g[:,:,0]))/2 + + # Add in the skip connection with per-channel D matrix + y = y + oe.contract('bhl,ch->bchl', u, self.D) + # Add back the initial state + if self.use_initial: + y = y + (2*(self.c*self.x0).unsqueeze(2)*base_term).sum(-1) + return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels. 
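    # Illustrative sketch (editorial, hypothetical toy sizes; not part of this patch):
    # fft_conv above is meant to match np.convolve(x, w)[:len(w)], i.e. a causal linear
    # convolution truncated to the input length. A tiny check, assuming fft_conv is in scope:
    #
    #   import torch
    #   B, C, H, L = 1, 1, 1, 6
    #   u, v = torch.randn(B, H, L), torch.randn(C, H, L)
    #   y = fft_conv(u, v)                          # (B, C, H, L)
    #   y_ref = torch.zeros(L)
    #   for k in range(L):
    #       for j in range(k + 1):
    #           y_ref[k] += u[0, 0, k - j] * v[0, 0, j]
    #   assert torch.allclose(y[0, 0, 0], y_ref, atol=1e-5)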
+ + def forward(self, u): + return self.zoh_method(u) if self.zero_order_hold else self.quadrature_method(u) + +# Below here are standard wrapper classes to handle +# (1) Non-linearity +# (2) Integration with the Hippo Code base +class NonLinear(nn.Module): + def __init__(self, h, channels, + ln=False, # Extra normalization + transposed=True, + dropout=0.0, + postact=None, # activation after FF + activation='gelu', # activation in between SS and FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + ): + super().__init__() + dropout_fn = DropoutNd # nn.Dropout2d bugged in PyTorch 1.11 + dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + #norm = Normalization(h*channels, transposed=transposed) if ln else nn.Identity() + + activation_fn = Activation(activation) + + output_linear = LinearActivation( + h*channels, + h, + transposed=transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + #self.f = nn.Sequential(activation_fn, dropout, norm, output_linear) + self.f = nn.Sequential(activation_fn, dropout, output_linear) + def forward(self,x): # Always (B H L) + return self.f(x) + +class SimpleS4Wrapper(nn.Module): + def __init__( + self, + d_model, + d_state=64, + channels=1, + bidirectional=False, + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + ln=True, # IGNORED: Extra normalization + postact=None, # activation after FF + activation='gelu', # activation in between SS and FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + linear=False, + # SSM Kernel arguments + **kernel_args, + ): + super().__init__() + self.h = d_model + self.d = d_state + self.channels = channels + #self.shift = shift + #self.linear = linear + self.out_d = self.h + self.transposed = transposed + self.bidirectional = bidirectional + assert not bidirectional, f"Bidirectional NYI" + self.s4 = SimpleS4(nHippos=d_model, d_state=d_state, + channels=channels, **kernel_args) + # the mapping + # We transpose if it's not in the forward. + nl = NonLinear(self.h, channels=self.channels, ln=ln, # Extra normalization + dropout=dropout, postact=postact, activation=activation, transposed=True, + initializer=initializer, weight_norm=weight_norm) + self.out = nn.Identity() if linear else nl + + def forward(self, u, state=None): + # u: (B H L) if self.transposed else (B L H) + if not self.transposed: u = u.transpose(-1, -2) + # We only pass BHL, and it is as if transposed is True. 
+ ret = self.out(self.s4(u)) + if not self.transposed: ret = ret.transpose(-1, -2) + return ret, state + + @property + def d_state(self): return self.h * self.d + + @property + def d_output(self): return self.out_d \ No newline at end of file diff --git a/src/utils/registry.py b/src/utils/registry.py index d8d0628a..32a81f68 100644 --- a/src/utils/registry.py +++ b/src/utils/registry.py @@ -71,7 +71,8 @@ "conv1d": "src.models.sequence.conv1d.Conv1d", "attsimp": "src.models.sequence.mha.AttentionSimple", "performer": "src.models.sequence.attention.linear.Performer", - "s4_2dconv": "src.models.sequence.ss.s4_2dconv.S42DConv" + "s4_2dconv": "src.models.sequence.ss.s4_2dconv.S42DConv", + "s4_simple": "src.models.sequence.ss.s4_simple.SimpleS4Wrapper", # 'packedrnn': 'models.sequence.rnns.packedrnn.PackedRNN', } From feefb6581a81c4f9c3de13d95e974949e4f230ef Mon Sep 17 00:00:00 2001 From: Dan Fu Date: Fri, 17 Jun 2022 13:33:30 -0700 Subject: [PATCH 2/6] Refactor S4 Simple --- configs/model/layer/s4_simple.yaml | 4 +- src/models/sequence/ss/s4_simple.py | 254 ------------------ src/models/sequence/ss/s4_simple/README.md | 20 ++ src/models/sequence/ss/s4_simple/s4_simple.py | 107 ++++++++ .../sequence/ss/s4_simple/s4_wrapper.py | 90 +++++++ src/models/sequence/ss/s4_simple/utils.py | 60 +++++ src/utils/registry.py | 2 +- 7 files changed, 280 insertions(+), 257 deletions(-) delete mode 100644 src/models/sequence/ss/s4_simple.py create mode 100644 src/models/sequence/ss/s4_simple/README.md create mode 100644 src/models/sequence/ss/s4_simple/s4_simple.py create mode 100644 src/models/sequence/ss/s4_simple/s4_wrapper.py create mode 100644 src/models/sequence/ss/s4_simple/utils.py diff --git a/configs/model/layer/s4_simple.yaml b/configs/model/layer/s4_simple.yaml index e436982f..9caa69e9 100644 --- a/configs/model/layer/s4_simple.yaml +++ b/configs/model/layer/s4_simple.yaml @@ -24,9 +24,9 @@ lr: 0.001 #deterministic: false # Special C init #l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to 1 and kernel will automatically resize #verbose: true -use_initial: true +use_initial: false learn_a: true -learn_theta: true +learn_theta: false trap_rule: false zero_order_hold: true theta_scale: false \ No newline at end of file diff --git a/src/models/sequence/ss/s4_simple.py b/src/models/sequence/ss/s4_simple.py deleted file mode 100644 index c7b2bae1..00000000 --- a/src/models/sequence/ss/s4_simple.py +++ /dev/null @@ -1,254 +0,0 @@ -import torch -import torch.nn as nn -from einops import rearrange, repeat -import opt_einsum as oe -from src.models.nn import Activation, LinearActivation - -import math - -# Replacement for Dropout in PyTorch 1.11.0 -class DropoutNd(nn.Module): - def __init__(self, p: float = 0.5, tie=True): - """ tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) - For some reason tie=False is dog slow, prob something wrong with torch.distribution - """ - super().__init__() - if p < 0 or p >= 1: - raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) - self.p = p - self.tie = tie - self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) - - def forward(self, X): - """ X: (batch, dim, lengths...) 
""" - if self.training: - # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) - mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape - # mask = self.binomial.sample(mask_shape) - mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p - return X * mask * (1.0/(1-self.p)) - return X - -class OurModule(nn.Module): - def __init__(self): super().__init__() - - def register(self, name, tensor, trainable=False, lr=None, wd=None): - """Utility method: register a tensor as a buffer or trainable parameter""" - - if trainable: - self.register_parameter(name, nn.Parameter(tensor)) - else: - self.register_buffer(name, tensor) - - optim = {} - if trainable and lr is not None: optim["lr"] = lr - if trainable and wd is not None: optim["weight_decay"] = wd - if len(optim) > 0: setattr(getattr(self, name), "_optim", optim) - -# -# This is intended to match np.convolve(x,w)[:len(w)] -# That is, (u \ast v)[k] = sum_{j} u[k-j]v[j] -# Here y = (u \ask v) on return. -# We assume the inputs are: -# u (B H L) -# v (C H L) -# and we want to produce y that is (B C H L) -# -def fft_conv(u,v): - L = u.shape[-1] - u_f = torch.fft.rfft(u, n=2*L) # (B H L) - v_f = torch.fft.rfft(v, n=2*L) # (C H L) - - y_f = oe.contract('bhl,chl->bchl', u_f, v_f) - y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) - return y - -class SimpleS4(OurModule): - def __init__(self, - nHippos, - d_state=64, - channels=1, - use_initial=True, # Use the initial state? - zero_order_hold=True, # Use zero-order hold approximation - trap_rule=False, - dt_min=0.001, - dt_max=0.1, - lr=None, # Hook to set LR of SSM parameters differently - learn_a=True, - learn_theta=True, - theta_scale=False, - **kernel_args,): # Use the trapezoid rule - super().__init__() - # H is number of hippos - # D is the dimension (also shockingly n other places) - # B is the batch - # L is the length - self.h = nHippos - self.d = d_state // 2 - self.channels = channels - self.use_initial = use_initial - self.zero_order_hold = zero_order_hold - # - # Use the trapezoid rule correct or just do zero-order hold. - self.trap_rule = trap_rule - - _fp = (self.channels, self.h, self.d) - - # Chebyshev initialization - h_scale = torch.exp(torch.arange(self.h)/self.h * math.log(dt_max/dt_min)) - angles = torch.arange(self.d)*torch.pi - theta_scale = h_scale if theta_scale else torch.ones(self.h) - theta = oe.contract('c,h,d->chd', torch.ones(self.channels), h_scale, angles) - a = -repeat(h_scale, 'h -> c h d', c=self.channels, d=self.d) - #a = -oe.contract('c,h,d->chd', torch.ones(self.channels), _log_T, - # torch.ones(self.d)) - - self.register("theta", theta,learn_theta,lr=lr, wd=None) - self.register("a", a, learn_a,lr=lr, wd=None) - # The other maps - self.D = nn.Parameter(torch.randn(channels, self.h)) - - if use_initial: - self.b = nn.Parameter(torch.randn(*_fp)) - self.c = nn.Parameter(torch.randn(*_fp)) - self.x0 = nn.Parameter(torch.randn(channels, self.h, self.d)) - else: - # This is an optimization that we combine q = c * b - # It's as if we're setting x0 = 0. 
- self.q = nn.Parameter(torch.randn(*_fp)) - - def zoh_method(self, u): - l = u.size(-1) - T = 1/(l-1) - zk = T*torch.arange(u.size(-1), device=u.device).view(1,1,-1,1) - ls = torch.complex(-self.a.abs(), self.theta) - term_0 = (torch.exp(ls*T) - 1)/ls - base_term = (2*term_0.unsqueeze(2)*torch.exp(ls.unsqueeze(2)* zk)).real - q = self.b*self.c if self.use_initial else self.q - f = (q.unsqueeze(2)*base_term).sum(-1) - y = fft_conv(u,f) - y = y + oe.contract('bhl,ch->bchl', u, self.D) - if self.use_initial: - # This the cosine formula from the note - cos_term = 2*T*torch.exp(-self.a.abs().unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) - y = y + (2*(self.c*self.x0).unsqueeze(2)*cos_term).sum(-1) - return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels. - - def quadrature_method(self, u): - # The input is now Batch x Hippos x Length - l = u.size(-1) - T = 1/(l-1) # the step size - zk = T*torch.arange(l, device=u.device).view(1,1,-1,1) - #T = torch.exp(self.log_T).to(u.device).unsqueeze(-1).unsqueeze(-1) - #zk = T*torch.arange(l, device=u.device).view(1,1,-1,1) - # q and a are both C x H x D - # zk is of length l we want a C x H x L matrix - # From the note, we have - # f[k] = 2 sum_{j=1}^{d} q_j e^{a_j z_k} cos( z_k * theta_j ) - # we construct the body of the sum - base_term = 2*T*torch.exp(-self.a.abs().unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) - q = self.b*self.c if self.use_initial else self.q - f = (q.unsqueeze(2)*base_term).sum(-1) - - # after summing f it is now an C H L matrix - g = u # this is a B H L matrix - # we want to convolve on L and produce a B H C L - # - y = fft_conv(g,f) - if self.trap_rule: - y = y - T*(oe.contract('ch,bhl -> bchl', f[:,:,0], g) + oe.contract('chl,bh -> bchl', f, g[:,:,0]))/2 - - # Add in the skip connection with per-channel D matrix - y = y + oe.contract('bhl,ch->bchl', u, self.D) - # Add back the initial state - if self.use_initial: - y = y + (2*(self.c*self.x0).unsqueeze(2)*base_term).sum(-1) - return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels. 
- - def forward(self, u): - return self.zoh_method(u) if self.zero_order_hold else self.quadrature_method(u) - -# Below here are standard wrapper classes to handle -# (1) Non-linearity -# (2) Integration with the Hippo Code base -class NonLinear(nn.Module): - def __init__(self, h, channels, - ln=False, # Extra normalization - transposed=True, - dropout=0.0, - postact=None, # activation after FF - activation='gelu', # activation in between SS and FF - initializer=None, # initializer on FF - weight_norm=False, # weight normalization on FF - ): - super().__init__() - dropout_fn = DropoutNd # nn.Dropout2d bugged in PyTorch 1.11 - dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() - #norm = Normalization(h*channels, transposed=transposed) if ln else nn.Identity() - - activation_fn = Activation(activation) - - output_linear = LinearActivation( - h*channels, - h, - transposed=transposed, - initializer=initializer, - activation=postact, - activate=True, - weight_norm=weight_norm, - ) - #self.f = nn.Sequential(activation_fn, dropout, norm, output_linear) - self.f = nn.Sequential(activation_fn, dropout, output_linear) - def forward(self,x): # Always (B H L) - return self.f(x) - -class SimpleS4Wrapper(nn.Module): - def __init__( - self, - d_model, - d_state=64, - channels=1, - bidirectional=False, - dropout=0.0, - transposed=True, # axis ordering (B, L, D) or (B, D, L) - ln=True, # IGNORED: Extra normalization - postact=None, # activation after FF - activation='gelu', # activation in between SS and FF - initializer=None, # initializer on FF - weight_norm=False, # weight normalization on FF - linear=False, - # SSM Kernel arguments - **kernel_args, - ): - super().__init__() - self.h = d_model - self.d = d_state - self.channels = channels - #self.shift = shift - #self.linear = linear - self.out_d = self.h - self.transposed = transposed - self.bidirectional = bidirectional - assert not bidirectional, f"Bidirectional NYI" - self.s4 = SimpleS4(nHippos=d_model, d_state=d_state, - channels=channels, **kernel_args) - # the mapping - # We transpose if it's not in the forward. - nl = NonLinear(self.h, channels=self.channels, ln=ln, # Extra normalization - dropout=dropout, postact=postact, activation=activation, transposed=True, - initializer=initializer, weight_norm=weight_norm) - self.out = nn.Identity() if linear else nl - - def forward(self, u, state=None): - # u: (B H L) if self.transposed else (B L H) - if not self.transposed: u = u.transpose(-1, -2) - # We only pass BHL, and it is as if transposed is True. - ret = self.out(self.s4(u)) - if not self.transposed: ret = ret.transpose(-1, -2) - return ret, state - - @property - def d_state(self): return self.h * self.d - - @property - def d_output(self): return self.out_d \ No newline at end of file diff --git a/src/models/sequence/ss/s4_simple/README.md b/src/models/sequence/ss/s4_simple/README.md new file mode 100644 index 00000000..8663409d --- /dev/null +++ b/src/models/sequence/ss/s4_simple/README.md @@ -0,0 +1,20 @@ +# S4 Simple + +This is the code for the blog post [Simplifying S4](https://hazyresearch.stanford.edu/blog/2022-06-11-simplifying-s4). +We present a simplified version of the S4 kernel with diagonal matrices and fewer learnable parameters. + +You can find the kernel in the `s4_simple.py` file. + +Running the code is as simple as (from the root directory of this repo): +``` +python -m train experiment=s4-simple-cifar wandb=null +``` +(You can remove `wandb=null` if you want to log the run to WandB.) 
+This code should reach 83-84% val accuracy on CIFAR10. + +By default, the kernel ignores the initial state (fusing `b` and `c`), and only trains the `a` parameters (leaving `theta` fixed to the initialization). +You can play with those parameters in the training run: +* Adding `use_initial=true` will add a learnable initial state, and learn the `b` and `c` parameters separately. +* Setting `learn_theta=true` will make the `theta` parameters learnable (we usually see a decrease in performance of about 3 points from this). +* Setting `leran_a=false` will make the `a` parameters not learnable. We don't see much of a performance degradation on CIFAR in this case, which speaks to the utility of the Chebyshev initialization! +* Setting `zero_order_hold=false` will switch from Zero-Order Hold to left-end-point quadrature. Additionally setting `trap_rule=true` will switch to the trapezoid rule (when `zxero_order_hold` is set to `false`). \ No newline at end of file diff --git a/src/models/sequence/ss/s4_simple/s4_simple.py b/src/models/sequence/ss/s4_simple/s4_simple.py new file mode 100644 index 00000000..36bc7dbd --- /dev/null +++ b/src/models/sequence/ss/s4_simple/s4_simple.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from einops import rearrange, repeat +import opt_einsum as oe +from src.models.sequence.ss.s4_simple.utils import OurModule, fft_conv + +import math + +class SimpleS4(OurModule): + def __init__(self, + nHippos, + d_state =64, + channels=1, + use_initial=True, # Use the initial state? + zero_order_hold=True, # Use zero-order hold approximation + trap_rule=False, + dt_min=0.001, + dt_max=0.1, + lr=None, # Hook to set LR of SSM parameters differently + learn_a=True, + learn_theta=True, + theta_scale=False, + **kernel_args,): # Use the trapezoid rule + super().__init__() + # H is number of hippos + # D is the dimension (also shockingly n other places) + # B is the batch + # L is the length + self.h = nHippos + self.d = d_state // 2 # Adjustment for conjugate pairs + self.channels = channels + self.use_initial = use_initial + self.zero_order_hold = zero_order_hold + # + # Use the trapezoid rule correct or just do zero-order hold. + self.trap_rule = trap_rule + + _fp = (self.channels, self.h, self.d) + + # Chebyshev initialization + h_scale = torch.exp(torch.arange(self.h)/self.h * math.log(dt_max/dt_min)) + angles = torch.arange(self.d)*torch.pi + theta_scale = h_scale if theta_scale else torch.ones(self.h) + theta = oe.contract('c,h,d->chd', torch.ones(self.channels), h_scale, angles) + a = -repeat(h_scale, 'h -> c h d', c=self.channels, d=self.d) + + self.register("theta", theta,learn_theta,lr=lr, wd=None) + self.register("a", a, learn_a,lr=lr, wd=None) + # The other maps + self.D = nn.Parameter(torch.randn(channels, self.h)) + + if use_initial: + self.b = nn.Parameter(torch.randn(*_fp)) + self.c = nn.Parameter(torch.randn(*_fp)) + self.x0 = nn.Parameter(torch.randn(channels, self.h, self.d)) + else: + # This is an optimization that we combine q = c * b + # It's as if we're setting x0 = 0. 
+ self.q = nn.Parameter(torch.randn(*_fp)) + + def zoh_method(self, u): + l = u.size(-1) + T = 1/(l-1) + zk = T*torch.arange(u.size(-1), device=u.device).view(1,1,-1,1) + ls = torch.complex(-self.a.abs(), self.theta) + term_0 = (torch.exp(ls*T) - 1)/ls + base_term = (2*term_0.unsqueeze(2)*torch.exp(ls.unsqueeze(2)* zk)).real + q = self.b*self.c if self.use_initial else self.q + f = (q.unsqueeze(2)*base_term).sum(-1) + y = fft_conv(u,f) + y = y + oe.contract('bhl,ch->bchl', u, self.D) + if self.use_initial: + # This the cosine formula from the note + cos_term = 2*T*torch.exp(-self.a.abs().unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) + y = y + (2*(self.c*self.x0).unsqueeze(2)*cos_term).sum(-1) + return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels. + + def quadrature_method(self, u): + # The input is now Batch x Hippos x Length + l = u.size(-1) + T = 1/(l-1) # the step size + zk = T*torch.arange(l, device=u.device).view(1,1,-1,1) + # q and a are both C x H x D + # zk is of length l we want a C x H x L matrix + # From the note, we have + # f[k] = 2 sum_{j=1}^{d} q_j e^{a_j z_k} cos( z_k * theta_j ) + # we construct the body of the sum + base_term = 2*T*torch.exp(-self.a.abs().unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) + q = self.b*self.c if self.use_initial else self.q + f = (q.unsqueeze(2)*base_term).sum(-1) + + # after summing f it is now an C H L matrix + g = u # this is a B H L matrix + # we want to convolve on L and produce a B H C L + y = fft_conv(g,f) + if self.trap_rule: + y = y - T*(oe.contract('ch,bhl -> bchl', f[:,:,0], g) + oe.contract('chl,bh -> bchl', f, g[:,:,0]))/2 + + # Add in the skip connection with per-channel D matrix + y = y + oe.contract('bhl,ch->bchl', u, self.D) + # Add back the initial state + if self.use_initial: + y = y + (2*(self.c*self.x0).unsqueeze(2)*base_term).sum(-1) + return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels. 
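    # Illustrative sketch (editorial, hypothetical toy sizes; not part of this patch):
    # the broadcast in quadrature_method builds
    #   f[k] = 2*T * sum_j q_j * exp(-|a_j| z_k) * cos(theta_j z_k)
    # for every channel/hippo at once. A scalar reference for comparison:
    #
    #   import torch
    #   C, H, D, L = 1, 2, 3, 8
    #   T = 1.0 / (L - 1)
    #   zk = T * torch.arange(L).view(1, 1, -1, 1)                     # (1, 1, L, 1)
    #   a, theta, q = torch.rand(C, H, D), torch.rand(C, H, D), torch.randn(C, H, D)
    #   base = 2 * T * torch.exp(-a.abs().unsqueeze(2) * zk) * torch.cos(theta.unsqueeze(2) * zk)
    #   f_vec = (q.unsqueeze(2) * base).sum(-1)                        # (C, H, L)
    #   f_ref = torch.zeros(C, H, L)
    #   for c in range(C):
    #       for h in range(H):
    #           for k in range(L):
    #               z = T * k
    #               f_ref[c, h, k] = 2 * T * (q[c, h] * torch.exp(-a[c, h].abs() * z)
    #                                                 * torch.cos(theta[c, h] * z)).sum()
    #   assert torch.allclose(f_vec, f_ref, atol=1e-5)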
+ + def forward(self, u): + return self.zoh_method(u) if self.zero_order_hold else self.quadrature_method(u) \ No newline at end of file diff --git a/src/models/sequence/ss/s4_simple/s4_wrapper.py b/src/models/sequence/ss/s4_simple/s4_wrapper.py new file mode 100644 index 00000000..ab27d0d3 --- /dev/null +++ b/src/models/sequence/ss/s4_simple/s4_wrapper.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +from src.models.nn import Activation, LinearActivation +from src.models.sequence.ss.s4_simple.utils import DropoutNd +from src.models.sequence.ss.s4_simple.s4_simple import SimpleS4 + +# Below here are standard wrapper classes to handle +# (1) Non-linearity +# (2) Integration with the Hippo Code base +class NonLinear(nn.Module): + def __init__(self, h, channels, + ln=False, # Extra normalization + transposed=True, + dropout=0.0, + postact=None, # activation after FF + activation='gelu', # activation in between SS and FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + ): + super().__init__() + dropout_fn = DropoutNd # nn.Dropout2d bugged in PyTorch 1.11 + dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + #norm = Normalization(h*channels, transposed=transposed) if ln else nn.Identity() + + activation_fn = Activation(activation) + + output_linear = LinearActivation( + h*channels, + h, + transposed=transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + #self.f = nn.Sequential(activation_fn, dropout, norm, output_linear) + self.f = nn.Sequential(activation_fn, dropout, output_linear) + def forward(self,x): # Always (B H L) + return self.f(x) + +class SimpleS4Wrapper(nn.Module): + def __init__( + self, + d_model, + d_state=64, + channels=1, + bidirectional=False, + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + ln=True, # IGNORED: Extra normalization + postact=None, # activation after FF + activation='gelu', # activation in between SS and FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + linear=False, + # SSM Kernel arguments + **kernel_args, + ): + super().__init__() + self.h = d_model + self.d = d_state + self.channels = channels + #self.shift = shift + #self.linear = linear + self.out_d = self.h + self.transposed = transposed + self.bidirectional = bidirectional + assert not bidirectional, f"Bidirectional NYI" + self.s4 = SimpleS4(nHippos=d_model, d_state=d_state, + channels=channels, **kernel_args) + # the mapping + # We transpose if it's not in the forward. + nl = NonLinear(self.h, channels=self.channels, ln=ln, # Extra normalization + dropout=dropout, postact=postact, activation=activation, transposed=True, + initializer=initializer, weight_norm=weight_norm) + self.out = nn.Identity() if linear else nl + + def forward(self, u, state=None): + # u: (B H L) if self.transposed else (B L H) + if not self.transposed: u = u.transpose(-1, -2) + # We only pass BHL, and it is as if transposed is True. 
+ ret = self.out(self.s4(u)) + if not self.transposed: ret = ret.transpose(-1, -2) + return ret, state + + @property + def d_state(self): return self.h * self.d + + @property + def d_output(self): return self.out_d \ No newline at end of file diff --git a/src/models/sequence/ss/s4_simple/utils.py b/src/models/sequence/ss/s4_simple/utils.py new file mode 100644 index 00000000..4dc9ef08 --- /dev/null +++ b/src/models/sequence/ss/s4_simple/utils.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import opt_einsum as oe + +# Replacement for Dropout in PyTorch 1.11.0 +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True): + """ tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + For some reason tie=False is dog slow, prob something wrong with torch.distribution + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """ X: (batch, dim, lengths...) """ + if self.training: + # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape + # mask = self.binomial.sample(mask_shape) + mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p + return X * mask * (1.0/(1-self.p)) + return X + +class OurModule(nn.Module): + def __init__(self): super().__init__() + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: optim["lr"] = lr + if trainable and wd is not None: optim["weight_decay"] = wd + if len(optim) > 0: setattr(getattr(self, name), "_optim", optim) + +# +# This is intended to match np.convolve(x,w)[:len(w)] +# That is, (u \ast v)[k] = sum_{j} u[k-j]v[j] +# Here y = (u \ask v) on return. 
+# We assume the inputs are: +# u (B H L) +# v (C H L) +# and we want to produce y that is (B C H L) +# +def fft_conv(u,v): + L = u.shape[-1] + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + v_f = torch.fft.rfft(v, n=2*L) # (C H L) + + y_f = oe.contract('bhl,chl->bchl', u_f, v_f) + y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) + return y \ No newline at end of file diff --git a/src/utils/registry.py b/src/utils/registry.py index 32a81f68..974a560d 100644 --- a/src/utils/registry.py +++ b/src/utils/registry.py @@ -72,7 +72,7 @@ "attsimp": "src.models.sequence.mha.AttentionSimple", "performer": "src.models.sequence.attention.linear.Performer", "s4_2dconv": "src.models.sequence.ss.s4_2dconv.S42DConv", - "s4_simple": "src.models.sequence.ss.s4_simple.SimpleS4Wrapper", + "s4_simple": "src.models.sequence.ss.s4_simple.s4_wrapper.SimpleS4Wrapper", # 'packedrnn': 'models.sequence.rnns.packedrnn.PackedRNN', } From 40d7dcb47f894a7ab8355cb658eceb0a7d2855a6 Mon Sep 17 00:00:00 2001 From: Dan Fu Date: Fri, 17 Jun 2022 13:34:12 -0700 Subject: [PATCH 3/6] README --- src/models/sequence/ss/s4_simple/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/sequence/ss/s4_simple/README.md b/src/models/sequence/ss/s4_simple/README.md index 8663409d..5c50a5ad 100644 --- a/src/models/sequence/ss/s4_simple/README.md +++ b/src/models/sequence/ss/s4_simple/README.md @@ -13,7 +13,7 @@ python -m train experiment=s4-simple-cifar wandb=null This code should reach 83-84% val accuracy on CIFAR10. By default, the kernel ignores the initial state (fusing `b` and `c`), and only trains the `a` parameters (leaving `theta` fixed to the initialization). -You can play with those parameters in the training run: +You can play with these parameters in the training run: * Adding `use_initial=true` will add a learnable initial state, and learn the `b` and `c` parameters separately. * Setting `learn_theta=true` will make the `theta` parameters learnable (we usually see a decrease in performance of about 3 points from this). * Setting `leran_a=false` will make the `a` parameters not learnable. We don't see much of a performance degradation on CIFAR in this case, which speaks to the utility of the Chebyshev initialization! From 3cdf355a33ff88118c3768ab4b9d331acbb7f2e2 Mon Sep 17 00:00:00 2001 From: Dan Fu Date: Fri, 17 Jun 2022 13:34:20 -0700 Subject: [PATCH 4/6] README --- src/models/sequence/ss/s4_simple/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/sequence/ss/s4_simple/README.md b/src/models/sequence/ss/s4_simple/README.md index 5c50a5ad..a1c9e861 100644 --- a/src/models/sequence/ss/s4_simple/README.md +++ b/src/models/sequence/ss/s4_simple/README.md @@ -17,4 +17,4 @@ You can play with these parameters in the training run: * Adding `use_initial=true` will add a learnable initial state, and learn the `b` and `c` parameters separately. * Setting `learn_theta=true` will make the `theta` parameters learnable (we usually see a decrease in performance of about 3 points from this). * Setting `leran_a=false` will make the `a` parameters not learnable. We don't see much of a performance degradation on CIFAR in this case, which speaks to the utility of the Chebyshev initialization! -* Setting `zero_order_hold=false` will switch from Zero-Order Hold to left-end-point quadrature. Additionally setting `trap_rule=true` will switch to the trapezoid rule (when `zxero_order_hold` is set to `false`). 
\ No newline at end of file +* Setting `zero_order_hold=false` will switch from Zero-Order Hold to left-end-point quadrature. Additionally setting `trap_rule=true` will switch to the trapezoid rule (when `zero_order_hold` is set to `false`). \ No newline at end of file From 6ef85f60dc25da8a1fd9399a66a70cc9388d52d3 Mon Sep 17 00:00:00 2001 From: Dan Fu Date: Fri, 17 Jun 2022 13:37:58 -0700 Subject: [PATCH 5/6] Update README --- src/models/sequence/ss/s4_simple/README.md | 13 +++++++++++-- src/models/sequence/ss/s4_simple/s4_wrapper.py | 2 +- src/models/sequence/ss/s4_simple/utils.py | 1 + 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/models/sequence/ss/s4_simple/README.md b/src/models/sequence/ss/s4_simple/README.md index a1c9e861..599fea28 100644 --- a/src/models/sequence/ss/s4_simple/README.md +++ b/src/models/sequence/ss/s4_simple/README.md @@ -5,7 +5,11 @@ We present a simplified version of the S4 kernel with diagonal matrices and fewe You can find the kernel in the `s4_simple.py` file. -Running the code is as simple as (from the root directory of this repo): +## Running the Code + +To run the code, first follow the install instructions of the overall repo. + +Then, running the code is as simple as (from the root directory of this repo): ``` python -m train experiment=s4-simple-cifar wandb=null ``` @@ -17,4 +21,9 @@ You can play with these parameters in the training run: * Adding `use_initial=true` will add a learnable initial state, and learn the `b` and `c` parameters separately. * Setting `learn_theta=true` will make the `theta` parameters learnable (we usually see a decrease in performance of about 3 points from this). * Setting `leran_a=false` will make the `a` parameters not learnable. We don't see much of a performance degradation on CIFAR in this case, which speaks to the utility of the Chebyshev initialization! -* Setting `zero_order_hold=false` will switch from Zero-Order Hold to left-end-point quadrature. Additionally setting `trap_rule=true` will switch to the trapezoid rule (when `zero_order_hold` is set to `false`). \ No newline at end of file +* Setting `zero_order_hold=false` will switch from Zero-Order Hold to left-end-point quadrature. Additionally setting `trap_rule=true` will switch to the trapezoid rule (when `zero_order_hold` is set to `false`). + +## Other Files + +There are two other files in this folder, `s4_wrapper.py` and `utils.py`. +They contain some standard wrapper classes and utils for integrating into the state spaces code base. 
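For illustration, one might combine these flags as command-line overrides in a single run. This is a sketch: the `model.layer.` prefix is an assumption based on the layout of `configs/model/layer/s4_simple.yaml` and `configs/experiment/s4-simple-cifar.yaml`, and the exact override path may differ.
```
python -m train experiment=s4-simple-cifar wandb=null \
    model.layer.use_initial=true model.layer.learn_theta=true \
    model.layer.zero_order_hold=false model.layer.trap_rule=true
```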
\ No newline at end of file diff --git a/src/models/sequence/ss/s4_simple/s4_wrapper.py b/src/models/sequence/ss/s4_simple/s4_wrapper.py index ab27d0d3..ec93d1c4 100644 --- a/src/models/sequence/ss/s4_simple/s4_wrapper.py +++ b/src/models/sequence/ss/s4_simple/s4_wrapper.py @@ -6,7 +6,7 @@ # Below here are standard wrapper classes to handle # (1) Non-linearity -# (2) Integration with the Hippo Code base +# (2) Integration with the State Spaces Code base class NonLinear(nn.Module): def __init__(self, h, channels, ln=False, # Extra normalization diff --git a/src/models/sequence/ss/s4_simple/utils.py b/src/models/sequence/ss/s4_simple/utils.py index 4dc9ef08..3cd2fd03 100644 --- a/src/models/sequence/ss/s4_simple/utils.py +++ b/src/models/sequence/ss/s4_simple/utils.py @@ -25,6 +25,7 @@ def forward(self, X): return X * mask * (1.0/(1-self.p)) return X +# Utility class for registering the learning rate in the state spaces repo class OurModule(nn.Module): def __init__(self): super().__init__() From 14ef30ffdbe37c8c0a303983593185c298a2323d Mon Sep 17 00:00:00 2001 From: Michael Zhang Date: Mon, 20 Jun 2022 19:07:40 -0400 Subject: [PATCH 6/6] Fix typo L22: Setting `learn_theta=true` (before was something else) --- src/models/sequence/ss/s4_simple/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/sequence/ss/s4_simple/README.md b/src/models/sequence/ss/s4_simple/README.md index 599fea28..2be29750 100644 --- a/src/models/sequence/ss/s4_simple/README.md +++ b/src/models/sequence/ss/s4_simple/README.md @@ -20,10 +20,10 @@ By default, the kernel ignores the initial state (fusing `b` and `c`), and only You can play with these parameters in the training run: * Adding `use_initial=true` will add a learnable initial state, and learn the `b` and `c` parameters separately. * Setting `learn_theta=true` will make the `theta` parameters learnable (we usually see a decrease in performance of about 3 points from this). -* Setting `leran_a=false` will make the `a` parameters not learnable. We don't see much of a performance degradation on CIFAR in this case, which speaks to the utility of the Chebyshev initialization! +* Setting `learn_a=false` will make the `a` parameters not learnable. We don't see much of a performance degradation on CIFAR in this case, which speaks to the utility of the Chebyshev initialization! * Setting `zero_order_hold=false` will switch from Zero-Order Hold to left-end-point quadrature. Additionally setting `trap_rule=true` will switch to the trapezoid rule (when `zero_order_hold` is set to `false`). ## Other Files There are two other files in this folder, `s4_wrapper.py` and `utils.py`. -They contain some standard wrapper classes and utils for integrating into the state spaces code base. \ No newline at end of file +They contain some standard wrapper classes and utils for integrating into the state spaces code base.
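As a minimal smoke test of the layer added by this series — a sketch, assuming the repo's `src` package is importable and the defaults shown in `s4_simple.py`/`s4_wrapper.py` above; shapes follow the `(B, H, L)` convention used when `transposed=True`:

```python
import torch
from src.models.sequence.ss.s4_simple.s4_wrapper import SimpleS4Wrapper

# d_model hippos, d_state // 2 conjugate pairs per hippo, a single channel
layer = SimpleS4Wrapper(d_model=128, d_state=64, channels=1, transposed=True)
u = torch.randn(2, 128, 1024)   # (batch, d_model, length)
y, state = layer(u)             # the incoming state (None by default) is passed straight through
print(y.shape)                  # expected: torch.Size([2, 128, 1024])
```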