Compare commits

..

1 Commit

Author SHA1 Message Date
Dr.Lt.Data
fe3a688a4a
Merge 1670612b90 into c515bdf371 2025-01-07 16:35:51 -05:00

View File

@@ -227,9 +227,8 @@ class T5(torch.nn.Module):
super().__init__() super().__init__()
self.num_layers = config_dict["num_layers"] self.num_layers = config_dict["num_layers"]
model_dim = config_dict["d_model"] model_dim = config_dict["d_model"]
inner_dim = config_dict["d_kv"] * config_dict["num_heads"]
self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations) self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.dtype = dtype self.dtype = dtype
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype) self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)