import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class OptimizedKANLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot grid, extended by `spline_order` knots on each side and
        # shared across input features: (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(torch.Tensor(out_features, in_features))

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Initialize the spline weights by fitting small random noise
            # sampled at the interior grid points.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor (memory-optimized version).
        """
        # Flatten higher-dimensional inputs to 2D.
        if x.dim() > 2:
            x = x.reshape(-1, self.in_features)
        assert x.size(-1) == self.in_features

        grid: torch.Tensor = self.grid  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Degree-0 bases: indicator of the knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Cox-de Boor recursion up to the requested spline order.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points
        (memory-optimized version).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)

        # Solve the batched least-squares problem in a memory-efficient way.
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """
        Forward pass (memory-optimized version).
        """
        # Remember the original shape and flatten the input to 2D.
        if x.dim() > 2:
            original_shape = x.shape
            x = x.reshape(-1, self.in_features)
        else:
            original_shape = None
        assert x.size(-1) == self.in_features

        # Compute the base activation and the B-spline bases only once.
        base_activated = self.base_activation(x)
        splines = self.b_splines(x)

        # Base branch.
        base_output = F.linear(base_activated, self.base_weight)

        # Spline branch, expressed as a single flattened linear map.
        spline_output = F.linear(
            splines.view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )

        # Combine the two branches.
        output = base_output + spline_output

        # Restore the original shape if needed.
        if original_shape is not None:
            output = output.view(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """
        Update the grid (memory-optimized version).
        """
        # Flatten higher-dimensional inputs to 2D.
        if x.dim() > 2:
            x = x.reshape(-1, self.in_features)
        assert x.size(-1) == self.in_features
        batch = x.size(0)

        # Compute the unreduced spline output in a memory-efficient way.
        splines = self.b_splines(x)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)

        # Batched matrix multiplication yields the per-feature spline outputs.
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # Sort each channel separately to capture the data distribution.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend the uniform and the data-adaptive grids.
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Use torch.cat rather than torch.concatenate for wider compatibility.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        # Refit the spline weights so the layer's output is preserved on the new grid.
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss (memory-optimized version).
        """
        # L1 regularization via the mean absolute spline coefficient per edge.
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()

        # Entropy regularization over the normalized edge magnitudes.
        p = l1_fake / (regularization_loss_activation + 1e-8)
        regularization_loss_entropy = -torch.sum(p * torch.log(p + 1e-8))

        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
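
# Minimal usage sketch for a single OptimizedKANLinear layer. The shapes and
# hyperparameters below are illustrative assumptions, not values from the
# original module: each output sums a SiLU base branch and a learned B-spline
# branch over the same input.
def _demo_kan_linear():
    layer = OptimizedKANLinear(in_features=4, out_features=8)
    x = torch.randn(16, 4)  # (batch, in_features)
    out = layer(x)
    assert out.shape == (16, 8)
    # Optionally adapt the knot grid to the observed input distribution.
    layer.update_grid(x)
    return out
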
class OptimizedKAN(nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Build one OptimizedKANLinear per consecutive pair of layer widths.
        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                OptimizedKANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
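
# Minimal usage sketch for OptimizedKAN; the layer widths are illustrative
# assumptions. `layers_hidden` is consumed pairwise, so [4, 16, 1] builds two
# OptimizedKANLinear layers: 4 -> 16 and 16 -> 1.
def _demo_kan():
    model = OptimizedKAN(layers_hidden=[4, 16, 1])
    x = torch.randn(32, 4)
    out = model(x)  # (32, 1)
    # update_grid=True refits each layer's grid on its current input first.
    out = model(x, update_grid=True)
    reg = model.regularization_loss()  # scalar tensor, added to the task loss
    return out, reg
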
class OptimizedKANForecaster(nn.Module):
    def __init__(
        self,
        input_features,
        hidden_sizes=[64, 128, 64],
        output_size=1,
        grid_size=5,
        spline_order=3,
        dropout_rate=0.1,
        output_sequence_length=1,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKANForecaster, self).__init__()
        # Note: `output_size` is currently unused; the output width is
        # determined by `output_sequence_length`.

        # Input projection layer.
        self.input_projection = nn.Linear(input_features, hidden_sizes[0])

        # Optimized KAN layers.
        layers_hidden = [hidden_sizes[0]] + hidden_sizes + [hidden_sizes[-1]]
        self.kan = OptimizedKAN(
            layers_hidden=layers_hidden,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )

        # Output layer.
        self.output_layer = nn.Linear(hidden_sizes[-1], output_sequence_length)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, x, update_grid=False):
        """
        Forward pass (memory-optimized version).
        x shape: [batch_size, seq_length, features]
        """
        batch_size, seq_len, features = x.shape

        # Apply the input projection.
        x_reshaped = x.reshape(-1, features)
        x = self.input_projection(x_reshaped)
        x = F.relu(x)
        x = self.dropout(x)

        # Reshape back to 3D to preserve batch and time-step structure.
        x = x.view(batch_size, seq_len, -1)

        # Pass through the KAN network (it handles the flattening internally).
        x = self.kan(x, update_grid)

        # Aggregate the time dimension (take the last time step).
        x = x[:, -1, :]

        # Output layer.
        x = self.output_layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss of the KAN network.
        """
        return self.kan.regularization_loss(
            regularize_activation=regularize_activation,
            regularize_entropy=regularize_entropy,
        )
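
# Minimal end-to-end sketch: one training step of OptimizedKANForecaster on
# random data. Batch size, sequence length, feature count, learning rate, and
# the regularization weight (1e-4) are all illustrative assumptions.
if __name__ == "__main__":
    model = OptimizedKANForecaster(
        input_features=8,
        hidden_sizes=[32, 64, 32],
        output_sequence_length=1,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(16, 24, 8)  # (batch_size, seq_length, features)
    y = torch.randn(16, 1)      # one-step-ahead targets

    pred = model(x)  # (batch_size, output_sequence_length)
    loss = F.mse_loss(pred, y) + 1e-4 * model.regularization_loss()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"training step complete, loss = {loss.item():.4f}")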