import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# KANLinear implementation
class KANLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Extended knot vector: grid_size intervals plus spline_order knots
        # padded on each side, shared across all input features.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(torch.Tensor(out_features, in_features))

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    def b_splines(self, x: torch.Tensor):
        """Compute the B-spline bases for the given input tensor.

        Uses the Cox-de Boor recursion: start from order-0 indicator
        functions over the grid intervals and raise the order
        spline_order times.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """Compute the coefficients of the curve that interpolates the given points."""
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.view(-1, self.in_features)

        # The output is the sum of a residual-style base branch and the
        # learned spline branch.
        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        output = output.view(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # Sort each channel individually to collect the data distribution.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend the uniform grid with the quantile-adaptive grid, then pad
        # spline_order extra knots on each side.
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        # Refit the spline coefficients so the layer computes (approximately)
        # the same function on the new grid.
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Compute the regularization loss: an L1 term on the mean spline
        magnitudes plus an entropy term, as an efficient proxy for the KAN
        paper's activation-based regularizer."""
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / (regularization_loss_activation + 1e-8)
        regularization_loss_entropy = -torch.sum(p * torch.log(p + 1e-8))
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )

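# Minimal smoke test for KANLinear (an illustrative sketch, not part of the
# original code; the layer sizes and batch size below are arbitrary
# assumptions). It checks that the B-spline bases and the forward pass
# produce the shapes documented above.
def _demo_kanlinear():
    layer = KANLinear(in_features=4, out_features=8, grid_size=5, spline_order=3)
    # Keep inputs inside the default grid_range [-1, 1] so every sample
    # falls on the spline grid.
    x = torch.rand(16, 4) * 2 - 1
    bases = layer.b_splines(x)
    # Each feature of each sample gets grid_size + spline_order basis values.
    assert bases.shape == (16, 4, 5 + 3)
    out = layer(x)
    assert out.shape == (16, 8)
    return out
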
# KAN model implementation
class KAN(nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # One KANLinear per consecutive pair of widths in layers_hidden.
        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

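# Sketch of one training step for the KAN stack above (illustrative only;
# the 1e-4 regularization weight and the every-100-steps grid-refresh
# cadence are assumptions, not values prescribed by this module).
def _demo_kan_training_step(model, x, y, optimizer, step):
    # Refit the spline grids to the current activation distribution only
    # occasionally; refitting on every step would slow training noticeably.
    pred = model(x, update_grid=(step % 100 == 0))
    loss = F.mse_loss(pred, y)
    # Add the L1/entropy regularizer so rarely used spline activations shrink.
    loss = loss + 1e-4 * model.regularization_loss()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
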
# KAN model for time-series forecasting
class KANForecaster(nn.Module):
    def __init__(
        self,
        input_features,
        hidden_sizes=[64, 128, 64],
        output_size=1,  # note: unused; the head is sized by output_sequence_length
        grid_size=5,
        spline_order=3,
        dropout_rate=0.1,
        output_sequence_length=1,
    ):
        super(KANForecaster, self).__init__()
        # Input projection layer
        self.input_projection = nn.Linear(input_features, hidden_sizes[0])

        # KAN layers; the first and last widths are repeated so the stack
        # starts and ends at the projection widths.
        layers_hidden = [hidden_sizes[0]] + hidden_sizes + [hidden_sizes[-1]]
        self.kan = KAN(
            layers_hidden=layers_hidden,
            grid_size=grid_size,
            spline_order=spline_order,
        )

        # Output layer
        self.output_layer = nn.Linear(hidden_sizes[-1], output_sequence_length)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, x, update_grid=False):
        # Input preprocessing: x has shape [batch_size, seq_length, features]
        batch_size, seq_len, features = x.shape

        # Flatten the batch and time dimensions
        x_reshaped = x.reshape(-1, features)

        # Apply the input projection
        x = self.input_projection(x_reshaped)
        x = F.relu(x)
        x = self.dropout(x)

        # Pass through the KAN network
        x = self.kan(x, update_grid)

        # Reshape back to (batch, time, hidden)
        x = x.view(batch_size, seq_len, -1)

        # Aggregate over the time dimension (take the last time step)
        x = x[:, -1, :]

        # Output layer
        x = self.output_layer(x)
        return x
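
# Quick end-to-end shape check (illustrative; the feature count, window
# length, batch size, and forecast horizon below are arbitrary assumptions).
if __name__ == "__main__":
    model = KANForecaster(
        input_features=6,
        hidden_sizes=[64, 128, 64],
        output_sequence_length=12,  # forecast 12 steps ahead
    )
    model.eval()  # disable dropout for a deterministic check
    # A batch of 32 windows, each 48 time steps of 6 features.
    x = torch.randn(32, 48, 6)
    with torch.no_grad():
        y_hat = model(x)
    print(y_hat.shape)  # expected: torch.Size([32, 12])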