Skip to content

Commit b6c96b7

Browse files
authored
Merge pull request thuml#316 from ChuckTG/feature-add-model-FreTS
Add FreTS model
2 parents c2f53c2 + be7b387 commit b6c96b7

File tree

3 files changed

+121
-2
lines changed

3 files changed

+121
-2
lines changed

exp/exp_basic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
44
Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
5-
Koopa, TiDE
5+
Koopa, TiDE, FreTS
66

77

88
class Exp_Basic(object):
@@ -27,6 +27,7 @@ def __init__(self, args):
2727
'iTransformer': iTransformer,
2828
'Koopa': Koopa,
2929
'TiDE': TiDE,
30+
'FreTS': FreTS
3031
}
3132
self.device = self._acquire_device()
3233
self.model = self._build_model().to(self.device)

models/FreTS.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import numpy as np
5+
6+
class Model(nn.Module):
7+
"""
8+
Paper link: https://arxiv.org/pdf/2311.06184.pdf
9+
"""
10+
def __init__(self, configs):
11+
super(Model, self).__init__()
12+
self.task_name = configs.task_name
13+
if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
14+
self.pred_len = configs.seq_len
15+
else:
16+
self.pred_len = configs.pred_len
17+
self.embed_size = 128 #embed_size
18+
self.hidden_size = 256 #hidden_size
19+
self.pred_len = configs.pred_len
20+
self.feature_size = configs.enc_in #channels
21+
self.seq_len = configs.seq_len
22+
self.channel_independence = configs.channel_independence
23+
self.sparsity_threshold = 0.01
24+
self.scale = 0.02
25+
self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))
26+
self.r1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
27+
self.i1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
28+
self.rb1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
29+
self.ib1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
30+
self.r2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
31+
self.i2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
32+
self.rb2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
33+
self.ib2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
34+
35+
self.fc = nn.Sequential(
36+
nn.Linear(self.seq_len * self.embed_size, self.hidden_size),
37+
nn.LeakyReLU(),
38+
nn.Linear(self.hidden_size, self.pred_len)
39+
)
40+
41+
# dimension extension
42+
def tokenEmb(self, x):
43+
# x: [Batch, Input length, Channel]
44+
x = x.permute(0, 2, 1)
45+
x = x.unsqueeze(3)
46+
# N*T*1 x 1*D = N*T*D
47+
y = self.embeddings
48+
return x * y
49+
50+
# frequency temporal learner
51+
def MLP_temporal(self, x, B, N, L):
52+
# [B, N, T, D]
53+
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on L dimension
54+
y = self.FreMLP(B, N, L, x, self.r2, self.i2, self.rb2, self.ib2)
55+
x = torch.fft.irfft(y, n=self.seq_len, dim=2, norm="ortho")
56+
return x
57+
58+
# frequency channel learner
59+
def MLP_channel(self, x, B, N, L):
60+
# [B, N, T, D]
61+
x = x.permute(0, 2, 1, 3)
62+
# [B, T, N, D]
63+
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on N dimension
64+
y = self.FreMLP(B, L, N, x, self.r1, self.i1, self.rb1, self.ib1)
65+
x = torch.fft.irfft(y, n=self.feature_size, dim=2, norm="ortho")
66+
x = x.permute(0, 2, 1, 3)
67+
# [B, N, T, D]
68+
return x
69+
70+
# frequency-domain MLPs
71+
# dimension: FFT along the dimension, r: the real part of weights, i: the imaginary part of weights
72+
# rb: the real part of bias, ib: the imaginary part of bias
73+
def FreMLP(self, B, nd, dimension, x, r, i, rb, ib):
74+
o1_real = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
75+
device=x.device)
76+
o1_imag = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
77+
device=x.device)
78+
79+
o1_real = F.relu(
80+
torch.einsum('bijd,dd->bijd', x.real, r) - \
81+
torch.einsum('bijd,dd->bijd', x.imag, i) + \
82+
rb
83+
)
84+
85+
o1_imag = F.relu(
86+
torch.einsum('bijd,dd->bijd', x.imag, r) + \
87+
torch.einsum('bijd,dd->bijd', x.real, i) + \
88+
ib
89+
)
90+
91+
y = torch.stack([o1_real, o1_imag], dim=-1)
92+
y = F.softshrink(y, lambd=self.sparsity_threshold)
93+
y = torch.view_as_complex(y)
94+
return y
95+
96+
def forecast(self, x_enc):
97+
# x: [Batch, Input length, Channel]
98+
B, T, N = x_enc.shape
99+
# embedding x: [B, N, T, D]
100+
x = self.tokenEmb(x_enc)
101+
bias = x
102+
# [B, N, T, D]
103+
if self.channel_independence == '1':
104+
x = self.MLP_channel(x, B, N, T)
105+
# [B, N, T, D]
106+
x = self.MLP_temporal(x, B, N, T)
107+
x = x + bias
108+
x = self.fc(x.reshape(B, N, -1)).permute(0, 2, 1)
109+
return x
110+
111+
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
112+
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
113+
dec_out = self.forecast(x_enc)
114+
return dec_out[:, -self.pred_len:, :] # [B, L, D]
115+
else:
116+
raise ValueError('Only forecast tasks implemented yet')
117+

run.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@
7171
help='time features encoding, options:[timeF, fixed, learned]')
7272
parser.add_argument('--activation', type=str, default='gelu', help='activation')
7373
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
74-
74+
parser.add_argument('--channel_independence', type=int, default=0,
75+
help='1: channel dependence 0: channel independence for FreTS model')
7576
# optimization
7677
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
7778
parser.add_argument('--itr', type=int, default=1, help='experiments times')

0 commit comments

Comments
 (0)