Papers Explained 290: rStar-Math

Ritvik Rastogi
7 min read · Jan 17, 2025


rStar-Math demonstrates that small language models (SLMs) can rival or even surpass the math reasoning capability of OpenAI o1, without distillation from superior models. This is achieved by exercising “deep thinking” through Monte Carlo Tree Search (MCTS), where a math policy SLM performs test-time search guided by an SLM-based process reward model.

Design Choices

The aim is to train a math policy SLM and a process reward model (PRM), and to integrate both within Monte Carlo Tree Search (MCTS) for System 2 deep thinking. MCTS is chosen for two key reasons:

  1. It breaks down complex math problems into simpler single-step generation tasks, reducing the difficulty for the policy SLM compared to other System 2 methods like Best-of-N or self-consistency, which require generating full solutions in one inference.
  2. The step-by-step generation in MCTS naturally yields step-level training data for both models. Standard MCTS rollout automatically assigns Q-value to each step based on its contribution to the final correct answer, obviating the need for human-generated step-level annotations for process reward model training.

Step-by-Step Verified Reasoning Trajectory

The method generates step-by-step verified reasoning trajectories with per-step Q-value annotations. Given a problem x and a policy model M, standard MCTS is run to incrementally construct a search tree for step-by-step solution exploration. The root node represents the question x, while child nodes correspond to intermediate steps s generated by M. A root-to-leaf path ending at a terminal node s_d forms a trajectory t = x ⊕ s_1 ⊕ s_2 ⊕ … ⊕ s_d, with each step s_i assigned a Q-value Q(s_i). From the search tree, the solution trajectories {t_1, t_2, …, t_n} (n ≥ 1) are extracted, and the goal is to select high-quality trajectories from this set to construct the training data. For this purpose, a code-augmented CoT synthesis method filters out low-quality generations, and extensive rollouts improve the reliability of the Q-value annotations.
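
As a minimal sketch of what such a tree and the extracted trajectories might look like in code (the node fields and helper below are illustrative, not the paper's implementation):

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """A node in the MCTS search tree: the root holds the question x, children hold steps s_i."""
    text: str                        # question (root) or the step's NL CoT + Python code
    q_value: float = 0.0             # Q(s_i), refined over rollouts
    visits: int = 0                  # N(s), visit count used by UCT and back-propagation
    is_terminal: bool = False        # terminal node s_d containing the final answer
    answer_correct: bool = False     # ground-truth check at the terminal node
    children: list["Node"] = field(default_factory=list)

def extract_trajectories(node, prefix=None):
    """Yield every root-to-leaf path t = x ⊕ s_1 ⊕ … ⊕ s_d with its per-step Q-values."""
    path = (prefix or []) + [node]
    if node.is_terminal or not node.children:
        yield [(n.text, n.q_value) for n in path], node.answer_correct
        return
    for child in node.children:
        yield from extract_trajectories(child, path)
```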

LLMs can hallucinate, producing incorrect steps that coincidentally lead to the correct answer. Such flawed steps are hard to identify, hence the code-augmented Chain-of-Thought approach: the policy model generates both a natural-language CoT step and the corresponding Python code, with the NL CoT embedded as a comment in the Python code.

An example of Code-augmented CoT.

Only generations where the Python code executes successfully are kept. Each generated step s_(i,j) is combined with the code from all preceding steps (s_1 ⊕ s_2 ⊕ … ⊕ s_(i-1) ⊕ s_(i,j)). If the combined code runs without error, the step is considered valid.
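
A rough sketch of this execution-based filter, assuming each step's Python code is available as a string (the helper and its signature are hypothetical, not the paper's code):

```python
import subprocess
import sys

def step_is_valid(previous_step_code: list[str], candidate_step_code: str,
                  timeout_s: float = 10.0) -> bool:
    """Concatenate the code of steps s_1 ... s_(i-1) with candidate step s_(i,j)
    and keep the candidate only if the combined program runs without error."""
    program = "\n".join(previous_step_code + [candidate_step_code])
    try:
        result = subprocess.run(
            [sys.executable, "-c", program],
            capture_output=True, timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return False
    return result.returncode == 0
```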

Valid steps are scored using a Process Preference Model (PPM), which assigns each step a Q-value q(s). The Upper Confidence bounds applied to Trees (UCT) algorithm then selects the best candidate step using the formula:

UCT(s) = q(s) + c · sqrt( ln N_parent(s) / N(s) )

where N(s) denotes the number of visits to node s, and N_parent(s) is the visit count of s's parent node. The predicted reward q(s) is provided by the PPM and is updated through back-propagation. The constant c balances exploitation and exploration.
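
Applied to the candidate children of the current node, this selection can be sketched as follows (the nodes are assumed to carry `q_value` and `visits` fields as in the earlier sketch, and the default value of c is a placeholder):

```python
import math

def select_child(children, n_parent: int, c: float = 1.4):
    """Pick the candidate step s maximizing UCT(s) = q(s) + c * sqrt(ln N_parent(s) / N(s)).
    `children` are tree nodes carrying `q_value` and `visits`."""
    def uct(child):
        if child.visits == 0:
            return float("inf")          # always try unvisited candidates first
        return child.q_value + c * math.sqrt(math.log(n_parent) / child.visits)
    return max(children, key=uct)
```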

Accurate Q-values are crucial for guiding MCTS towards correct solutions. Extensive rollouts refine Q-value estimates, similar to how repeated gameplay improves move evaluation in Go. Two self-annotation methods are used:

  • Terminal-Guided Annotation (Rounds 1 & 2): Used when the PPM is unavailable or unreliable. Each step's Q-value is updated according to whether the final answer is correct: q(s_i)^k = q(s_i)^(k-1) + q(s_d)^k, where k is the rollout number and q(s_d)^k is 1 for a correct final answer and -1 for an incorrect one; the initial Q-value q(s_i)^0 is 0 (see the back-propagation sketch after this list).
  • PRM-Augmented Annotation (Rounds 3 & 4): Uses the PPM to provide an initial Q-value for each step, q(s_i)^0 = PPM(x ⊕ s_1 ⊕ s_2 ⊕ … ⊕ s_i), which is then updated through MCTS back-propagation using the same formula as terminal-guided annotation. The PPM is not used to score the terminal node s_d; ground-truth labels are used there for more accurate scoring.
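
A minimal sketch of the terminal-guided update along one rollout's root-to-leaf path, using the illustrative node fields from the earlier sketch:

```python
def backpropagate_terminal(path, answer_correct: bool) -> None:
    """Rollout k reached the terminal node s_d; update every step on the path with
    q(s_i)^k = q(s_i)^(k-1) + q(s_d)^k, where q(s_d)^k is +1 if the final answer
    is correct and -1 otherwise (all q(s_i)^0 start at 0)."""
    terminal_reward = 1.0 if answer_correct else -1.0
    for node in path:
        node.q_value += terminal_reward
        node.visits += 1
```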

Process Preference Model

Existing methods rely on precise numerical scores for each step, derived either from human annotation or MCTS Q-values. These scores are used as training targets for the reward model, minimizing the difference between predicted and labeled scores. However, accurately assigning precise numerical scores to individual steps is extremely difficult, even for experts. Distinguishing between slightly better or worse steps is challenging and leads to noisy training data.

The process preference model instead bypasses the need for precise scores by focusing on relative preferences between steps. It constructs positive-negative preference pairs from the MCTS search tree, using Q-values to guide the selection.

Constructing Preference Pairs (a code sketch of the intermediate-step case follows the list):

  • For intermediate steps:
  1. Two steps with the highest Q-values leading to a correct final answer are selected as positive examples.
  2. Two steps with the lowest Q-values leading to an incorrect final answer are selected as negative examples.
  3. Crucially, the positive and negative pairs share the same preceding steps (i.e., they branch from the same point in the reasoning trajectory).
  • For the final answer step:
  1. Since two trajectories that share the same preceding steps rarely end in different final answers, the shared-prefix restriction is relaxed for this step.
  2. Two complete trajectories with the highest average Q-values leading to correct answers are selected as positive examples.
  3. Two complete trajectories with the lowest average Q-values leading to incorrect answers are selected as negative examples.
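
For the intermediate-step case, the pair construction can be sketched as follows; the node attributes (in particular `leads_to_correct_answer`) are illustrative bookkeeping, not the paper's data structures:

```python
def step_preference_pairs(parent, num_pairs: int = 2):
    """From one expansion point in the tree, pair the highest-Q candidate steps that
    eventually reach a correct final answer (positives) with the lowest-Q candidates
    that end in an incorrect answer (negatives). Because both sides are children of
    the same `parent`, every pair shares the same preceding steps.
    `leads_to_correct_answer` is a hypothetical per-node flag set during rollouts."""
    correct = sorted((c for c in parent.children if c.leads_to_correct_answer),
                     key=lambda c: c.q_value, reverse=True)[:num_pairs]
    incorrect = sorted((c for c in parent.children if not c.leads_to_correct_answer),
                       key=lambda c: c.q_value)[:num_pairs]
    return list(zip(correct, incorrect))
```

Restricting both sides of each pair to children of the same parent is what enforces the shared-preceding-steps requirement described above.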

Self-Evolved Deep Thinking

Training with Step-by-Step Verified Reasoning Trajectory

Data Collection and Preparation:

  • Math Problems: A dataset of 747k math word problems with ground-truth answers is compiled, primarily from NuminaMath (competition-level problems) and MetaMath. GPT-4 is used to synthesize additional competition-level problems based on the MATH and AMC-AIME datasets: GPT-4 generates 10 solutions per synthesized problem, and only problems with at least 3 consistent solutions are kept, filtering out unsolvable or incorrectly labeled ones.
  • Reasoning Trajectories: Instead of using existing solutions, MCTS is used to generate higher-quality, step-by-step reasoning trajectories with annotated Q-values. 16 rollouts are performed per problem in each self-evolution round, and problems are categorized by difficulty (easy, medium, hard) based on the success rate of these rollouts (a rough sketch of this bucketing follows the list). Hard problems with no correct trajectories undergo an additional 16 MCTS rollouts.
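
The difficulty bucketing can be sketched as a simple function of the rollout success rate; the numeric cut-offs below are hypothetical placeholders, since the write-up only names the easy/medium/hard buckets:

```python
def categorize_difficulty(num_correct: int, num_rollouts: int = 16) -> str:
    """Bucket a problem by how many of its MCTS rollouts reach the correct answer.
    The cut-offs are hypothetical; the paper only specifies the three buckets."""
    rate = num_correct / num_rollouts
    if rate >= 0.75:      # hypothetical cut-off
        return "easy"
    if rate >= 0.25:      # hypothetical cut-off
        return "medium"
    return "hard"         # hard problems with num_correct == 0 get 16 extra rollouts
```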

Model Training:

  • Policy SLM (Supervised Fine-Tuning): The policy SLM is fine-tuned using the top-2 trajectories with the highest average Q-values that lead to correct answers. This focuses the training on high-quality reasoning steps.
  • PPM (Process Preference Model Training): The PPM is initialized from the fine-tuned policy model and trained with a pairwise ranking loss. Problems with a mix of correct and incorrect trajectories are used: for each step, two positive examples (highest Q-values, correct answer) and two negative examples (lowest Q-values, incorrect answer) are selected to form preference pairs (a sketch of the ranking loss follows this list).
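
The pairwise ranking loss can be written in the standard Bradley-Terry form used for reward models; a minimal PyTorch sketch is below (the exact loss formulation in the paper may differ), assuming `pos_scores`/`neg_scores` are the PPM's scalar outputs for the chosen and rejected steps of each pair:

```python
import torch
import torch.nn.functional as F

def ppm_pairwise_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
    """Pairwise ranking loss (Bradley-Terry style): push the PPM to score the
    preferred step above the rejected step that shares the same prefix."""
    return -F.logsigmoid(pos_scores - neg_scores).mean()

# Illustrative usage with dummy PPM scores for two preference pairs:
pos = torch.tensor([0.8, 0.5], requires_grad=True)
neg = torch.tensor([-0.2, 0.1], requires_grad=True)
loss = ppm_pairwise_loss(pos, neg)
loss.backward()
```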

Recipe for Self-Evolution

  • Round 1 (Bootstrapping): DeepSeek-Coder-V2-Instruct (236B) is used for initial data generation. Terminal-guided annotation is used for Q-values due to the lack of a reliable PPM. 8 MCTS rollouts are performed per problem for efficiency. This round focuses on creating an initial strong policy SLM (SLM-r1). PPM-r1 is also trained, but is less effective due to limited rollouts.
  • Round 2 (Reliable PPM): SLM-r1 is used. 16 MCTS rollouts are performed per problem for more accurate Q-values. This round focuses on training a reliable PPM (PPM-r2). SLM-r2 is also trained.
  • Round 3 (PPM-Augmented MCTS): PPM-r2 is used to guide MCTS, significantly improving trajectory quality and expanding the training set with more complex problems. SLM-r3 and PPM-r3 are trained.
  • Round 4 (Solving Challenging Problems): Focuses on solving remaining difficult problems, particularly Olympiad-level. Additional MCTS rollouts (up to 128) and multiple tree expansions with different random seeds are used for unsolved problems after 16 initial rollouts. This increases Olympiad-level problem coverage to 80.58%. The self-evolution process stops after this round because the remaining unsolved problems are deemed to be of low quality (incorrectly labeled).

Evaluation

rStar-Math is applied to several LLMs of varying sizes, and its performance is evaluated on a diverse set of mathematical benchmarks against several System 1 and System 2 baselines, including state-of-the-art LLMs like GPT-4, Claude, and OpenAI models, as well as open-sourced reasoning models. The evaluation metric is Pass@1 accuracy.

The results of rStar-Math and other frontier LLMs on the most challenging math benchmarks:

  • rStar-Math significantly improves the mathematical reasoning performance of LLMs, achieving performance comparable to or exceeding much larger models like OpenAI’s o1, even with smaller model sizes (1.5B-7B).
  • rStar-Math outperforms state-of-the-art System 2 baselines, even with a smaller reward model. It consistently improves the accuracy of all base models, surpassing even Best-of-N baselines that use a 10x larger reward model.
  • rStar-Math demonstrates strong generalizability across diverse and challenging math benchmarks, including Olympiad Bench, College Math, and Gaokao, achieving state-of-the-art results.

Reasoning performance when scaling up test-time compute:

  • Increasing test-time computation (number of trajectories) generally improves performance, although the degree of improvement varies across benchmarks. rStar-Math with just 4 trajectories significantly outperforms Best-of-N baselines.

Paper

rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking (arXiv:2501.04519)

Hungry for more insights?

Don’t miss out on exploring other fascinating threads in this series. Simply click here and uncover the state-of-the-art research!

Do Subscribe for weekly updates!!
