MaximeMuhlethaler committed on
Commit
328bc4e
·
verified ·
1 Parent(s): b1f8d8d

Chess Challenge submission by MaximeMuhlethaler

Browse files
Files changed (2) hide show
  1. model.py +5 -10
  2. tokenizer.py +1 -4
model.py CHANGED
@@ -24,7 +24,6 @@ class ChessConfig(PretrainedConfig):
24
  dropout=0.1,
25
  layer_norm_epsilon=1e-5,
26
  tie_weights=True,
27
- # Valeurs par défaut strictes
28
  pad_token_id=0,
29
  bos_token_id=1,
30
  eos_token_id=2,
@@ -41,7 +40,7 @@ class ChessConfig(PretrainedConfig):
41
  self.layer_norm_epsilon = layer_norm_epsilon
42
  self.tie_weights = tie_weights
43
 
44
- # On passe les IDs vitaux à kwargs pour le parent
45
  kwargs["pad_token_id"] = pad_token_id
46
  kwargs["bos_token_id"] = bos_token_id
47
  kwargs["eos_token_id"] = eos_token_id
@@ -118,7 +117,7 @@ class ChessForCausalLM(PreTrainedModel):
118
  def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
119
 
120
  def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, return_dict=None, **kwargs):
121
- # 1. FIX TYPE RETOUR
122
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
123
  if return_dict is None: return_dict = True
124
 
@@ -132,17 +131,13 @@ class ChessForCausalLM(PreTrainedModel):
132
  x = self.ln_f(x)
133
  logits = self.lm_head(x)
134
 
135
- # ---------------------------------------------------------
136
- # 2. PATCH NUCLÉAIRE : On bannit 0, 1, 2, 3 en dur
137
- # ---------------------------------------------------------
138
  if labels is None:
139
- # PAD=0, BOS=1, EOS=2, UNK=3 (Les standards de ton tokenizer)
140
  nuclear_bad_ids = [0, 1, 2, 3]
141
 
142
- # On met -infini (impossible à choisir)
143
- # Le slicing [:, :, ids] couvre tout le batch et toute la séquence
144
  logits[:, :, nuclear_bad_ids] = float("-inf")
145
- # ---------------------------------------------------------
146
 
147
  loss = None
148
  if labels is not None:
 
24
  dropout=0.1,
25
  layer_norm_epsilon=1e-5,
26
  tie_weights=True,
 
27
  pad_token_id=0,
28
  bos_token_id=1,
29
  eos_token_id=2,
 
40
  self.layer_norm_epsilon = layer_norm_epsilon
41
  self.tie_weights = tie_weights
42
 
43
+
44
  kwargs["pad_token_id"] = pad_token_id
45
  kwargs["bos_token_id"] = bos_token_id
46
  kwargs["eos_token_id"] = eos_token_id
 
117
  def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
118
 
119
  def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, return_dict=None, **kwargs):
120
+
121
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
122
  if return_dict is None: return_dict = True
123
 
 
131
  x = self.ln_f(x)
132
  logits = self.lm_head(x)
133
 
134
+
 
 
135
  if labels is None:
136
+
137
  nuclear_bad_ids = [0, 1, 2, 3]
138
 
 
 
139
  logits[:, :, nuclear_bad_ids] = float("-inf")
140
+
141
 
142
  loss = None
143
  if labels is not None:
tokenizer.py CHANGED
@@ -108,17 +108,14 @@ class ChessTokenizer(PreTrainedTokenizer):
108
  from datasets import load_dataset
109
  from collections import Counter
110
 
111
- # On charge en streaming pour aller vite
112
  ds = load_dataset(dataset_name, split="train", streaming=True)
113
- ds = ds.take(50000) # 50k parties suffisent pour voir tous les coups possibles
114
 
115
  counter = Counter()
116
  for ex in ds:
117
- # On normalise avant de compter !
118
  moves = [normalize_move(t) for t in ex["text"].split()]
119
  counter.update(moves)
120
 
121
- # On garde les tokens spéciaux + les N plus fréquents
122
  special = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
123
  most_common = counter.most_common(max_vocab_size - len(special))
124
 
 
108
  from datasets import load_dataset
109
  from collections import Counter
110
 
 
111
  ds = load_dataset(dataset_name, split="train", streaming=True)
112
+ ds = ds.take(50000)
113
 
114
  counter = Counter()
115
  for ex in ds:
 
116
  moves = [normalize_move(t) for t in ex["text"].split()]
117
  counter.update(moves)
118
 
 
119
  special = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
120
  most_common = counter.most_common(max_vocab_size - len(special))
121