This is an implementation of a completion engine that incrementally parses type-safe programs, guaranteeing that intermediate outputs can be completed to type-safe programs. The completion engine can be used to constrain sampling from an LLM to type-safe programs only. The implementation currently only handles TypeScript.
More details on the properties of the completion engine and supported features can be found in the paper Type-Constrained Code Generation with Language Models.
When set up correctly, the package can be used to sample type-safe TypeScript programs from a language model. The following incrementally generates the code for a TypeScript merge sort function while ensuring that the generated code is type-safe:
from typesafe_llm.sampling import sample_constrained
sample_constrained(
    prompt="function merge_sort(x:number[]):number[] {",
    max_tokens=100,
    device="cuda",
    model_name="google/gemma-2-2b-it",
    temperature=0,
    do_sample=False,
    trace=True,
)
print("Generation completed")
The project contains two main parts:
- The sampling algorithm, which is used to sample type-safe TypeScript programs from a language model.
- The parser, which is used to parse TypeScript programs and check their completability to type-safe programs.
To install the package, we recommend setting up a conda environment. The instructions below assume NVIDIA GPUs.
git clone https://github.com/eth-sri/type-constrained-code-generation.git
cd type-constrained-code-generation
conda create -n typesafe_llm python=3.11
conda activate typesafe_llm
# for LLM inference
# set up torch
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y
# install flash-attention
pip install flash-attn==2.7.3 --no-build-isolation
# install package
pip install -e .
If you only want to use the parser and do not want to sample from a language model, you can skip the installation of torch and flash-attention.
To sample type-safe TypeScript programs from a language model, use the sample_constrained function from the typesafe_llm.sampling module.
from typesafe_llm.sampling import sample_constrained
sample = sample_constrained(
    prompt="function merge_sort(x:number[]):number[] {",
    max_tokens=100,
    device="cuda",
    model_name="google/gemma-2-2b-it",
    temperature=0.1,
    do_sample=True,
)
print(sample)
If a GPU is available, set device to "cuda"; on a MacBook Pro, set device to "mps" (when PyTorch nightly is installed).
Setting the device to "cpu" always works.
The trace parameter controls debugging output for live inspection of the generation process; set it to False for programmatic use.
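For programmatic use, the same call can be made without the debugging trace, for example on CPU. This is a minimal variant of the example above (CPU inference will be slow):
from typesafe_llm.sampling import sample_constrained

# CPU-only, non-interactive invocation: greedy decoding, no live trace output
sample = sample_constrained(
    prompt="function merge_sort(x:number[]):number[] {",
    max_tokens=100,
    device="cpu",          # works everywhere, but slower than "cuda"/"mps"
    model_name="google/gemma-2-2b-it",
    temperature=0,
    do_sample=False,       # greedy decoding
    trace=False,           # no debugging output
)
print(sample)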
You can also independently use the parser to parse TypeScript programs and check their completability.
from typesafe_llm.parser.parser_ts import parse_ts_program
states = parse_ts_program("let x:number = 1;x;")
print(list(states))
# only one accepting state
states = parse_ts_program('let x:number = "he')
print(list(states))
# some accepting states, could be saved by y".length
states = parse_ts_program('let x:boolean = 1 < "hey" +')
print(list(states))
# no states, cannot turn "hey" + ... into a number, but a number is needed for the < operator
states = parse_ts_program('let x:number = 1;let y')
print(list(states))
# two partial states, one where the second variable has name "y" and one where it is not completed yet
To run the tests, you can use the following command:
pytest test
In this section we provide an overview of how to reproduce the experiments presented in our paper.
To reproduce our experiments locally, high-end GPUs are required, e.g., an NVIDIA A100 80GB. The package includes setup scripts for all software requirements using miniconda. Required hardware/software:
- x86/64 architecture CPUs
- 80GB GPU VRAM
- CUDA 12.4 or newer
Further, the Gemma 2 model family requires accepting an EULA. Please create a Hugging Face account and visit the model pages to accept the EULA:
- https://huggingface.co/google/gemma-2-2b-it
- https://huggingface.co/google/gemma-2-9b-it
- https://huggingface.co/google/gemma-2-27b-it
You will later be asked for a Hugging Face access token. Log in with the account with which you accepted the EULA and visit the access token page to generate one: https://huggingface.co/settings/tokens
Follow the installation instructions to install conda and all dependencies for the experiments:
bash ./setup_conda.sh
# Restart your shell
bash ./setup_env.sh
# NOTE: Some models are guarded on huggingface, so you will need to visit their model page, accept the EULA and enter the huggingface Access Token to your account when prompted. See section "Requirements" for more details.
Important note: before running the experiments, you need to download the models and datasets they use. We provide a script for this; it must be run before starting the experiments.
You may specify which models to download by passing the --models parameter:
python3 experiments/main/download_models.py --models google/gemma-2-2b-it,google/gemma-2-9b-it
To download all required models and datasets, run the following command:
python3 experiments/main/download_models.py
To warm up, we start by reproducing the synthesis result for the smallest model (Gemma 2 2B) on the MBPP dataset. To avoid using busy GPUs in a shared setting, use nvidia-smi to check which GPUs are free. Then specify the IDs of the GPUs you want to use via the CUDA_VISIBLE_DEVICES environment variable. For example, to use GPUs 0 and 1, run the following command:
CUDA_VISIBLE_DEVICES=0,1 python3 experiments/main/run_experiments_syn_tran.py --models google/gemma-2-2b-it --tasks synth --subsets mbpp
This reproduces the results for Gemma 2 2B on the synthesis task on MBPP.
The experiment should finish within approximately 4 hours on a single GPU.
The results of this experiment (and all other results) are stored in experiments/main/results in appropriately named jsonl files. The general schema is experiments/main/results/<subset>_<model>_s=<seed>_t=<temperature>_<task>_<constrained>.jsonl. In this concrete example, the files are experiments/main/results/mbpp_google_gemma-2-2b-it_s=0_t=1_synth_nc.jsonl and ..._c.jsonl for the unconstrained and type-constrained variants, respectively.
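To check which result files are already present and how far each run has progressed, a small script along the following lines can be used (a sketch, assuming one JSON record per completed instance in each jsonl file):
import glob

# list result files matching the documented naming schema and report how many
# instances each already contains (one JSON record per line)
for path in sorted(glob.glob("experiments/main/results/*.jsonl")):
    with open(path) as f:
        completed = sum(1 for _ in f)
    print(f"{path}: {completed} completed instances")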
The experiment runs can be cancelled at any time; intermediate results are stored in the jsonl files. Upon restarting, the script will automatically pick up at the last completed instance and continue from there. It may happen that running tasks daemonize and keep running (check nvidia-smi). Make sure to kill them manually before restarting.
Our experiment script automatically distributes jobs over the indicated GPUs. The script then repeatedly checks whether running jobs have completed and whether GPUs have become available. You will therefore see output like the following:
+ CUDA_VISIBLE_DEVICES=0 python3 inference_multiple.py --max-tokens 1000 --timeout 300 --model_name google/gemma-2-2b-it --seed 0 --temp 1 --subset mbpp --try_top_k 10000000000000000 --constrained False --output_file 'results/mbpp_google_gemma-2-2b-it_s=0_t=1_synth_nc.jsonl'
+ CUDA_VISIBLE_DEVICES=1 python3 inference_multiple.py --max-tokens 1000 --timeout 300 --model_name google/gemma-2-2b-it --seed 0 --temp 1 --subset mbpp --try_top_k 10000000000000000 --constrained True --output_file 'results/mbpp_google_gemma-2-2b-it_s=0_t=1_synth_c.jsonl'
Total jobs: 2, Running jobs: 2, Remaining jobs: 0. Waiting for running jobs to finish...
To reproduce other tasks, the following commands reproduce the results for the translation task and the repair task on MBPP, and should take around 4 hours each:
CUDA_VISIBLE_DEVICES=0,1 python3 experiments/main/run_experiments_syn_tran.py --models google/gemma-2-2b-it --tasks translate --subsets mbpp
CUDA_VISIBLE_DEVICES=0,1 python3 experiments/main/run_experiments_repair.py --models google/gemma-2-2b-it --subsets mbpp
Then you can run more experiments for synthesis and translation by providing different models (--models), tasks (--tasks), and benchmarks (--subsets). Remember to set CUDA_VISIBLE_DEVICES.
Note that a single 80 GB GPU provides sufficient VRAM to host any model used in our experiments.
CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_syn_tran.py --models google/gemma-2-2b-it,google/gemma-2-9b-it --tasks synth --subsets mbpp,humaneval
CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_syn_tran.py --models Qwen/Qwen2.5-32B-Instruct --tasks translate --subsets mbpp
You can similarly start the repair task:
CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_repair.py --models google/gemma-2-2b-it,google/gemma-2-9b-it --subsets mbpp,humaneval
CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_repair.py --models Qwen/Qwen2.5-32B-Instruct --subsets mbpp
Below is the list of all options for these parameters. Running all these options will cover all our experiments but can take several days. For the sake of time, reviewers may check a subset that they are interested in.
FLAGS
--models=MODELS
Default: google/gemma-2-2b-it,google/gemma-2-9b-it,google/gemma-2-27b-it,deepseek-ai/deepseek-coder-33b-instruct,codellama/CodeLlama-34b-Instruct-hf,Qwen/Qwen2.5-32B-Instruct
--tasks=TASKS (only for experiments/main/run_experiments_syn_tran.py)
Default: synth,translate
--subsets=SUBSETS
Default: humaneval,mbpp
You can also obtain the full list of available parameters:
python3 experiments/main/run_experiments_syn_tran.py --help
python3 experiments/main/run_experiments_repair.py --help
The runtime of our main experiments depends on the choice of datasets and tasks and the choice of models. Generally, larger datasets and larger models result in longer execution times.
Our benchmark features the MBPP and HumanEval datasets, adapted for three tasks: synthesis, translation, and repair. Taking into account additional instances due to running on several seeds, the experiments, in increasing order of runtime, are: MBPP-repair, HumanEval-repair, MBPP-{synthesis,translation}, and HumanEval-{synthesis,translation}.
Our evaluation further features 6 models, in order of increasing parameter size: Gemma 2 2B, Gemma 2 9B, Gemma 2 27B, Qwen 2.5 32B, DeepSeek Coder 33B, and CodeLlama 34B.
Thus, the quickest experiment is computing the performance of Gemma 2 2B synthesis on MBPP, taking approximately 4 hours on a single GPU. The longest experiment is computing the performance of CodeLlama 34B synthesis on HumanEval.
You can run the following command to produce the figures for the paper. You may run this script with partial results, in which case you will receive a printout of the missing results, and their positions in the tables will be substituted with "-1".
bash experiments/main/figures.sh
The results map to the corresponding figures in the paper as follows:
- Tables 2 and 3: all models, all tasks, all datasets, i.e., [mbpp|humaneval]_*_s=[0|1|2|3]_t=1_[synth|translate|repair-all]_[c|nc].jsonl. Vanilla and Syntax can be computed based on the non-constrained (nc) variants.
- Table 4: all models, synthesis, all datasets, i.e., [mbpp|humaneval]_*_s=[0|1|2|3]_t=1_synth_[c|nc].jsonl
- Figure 8: Gemma 2 2B, synthesis, HumanEval, i.e., humaneval_google_gemma-2-2b-it_s=[0|1|2|3]_t=1_synth_[c|nc].jsonl
Since running the entire pipeline takes several days on 8 GPUs, we have included our raw data in the experiments/main/results_paper directory. You can run the figures script directly on the submitted results, without running the experiments, like this:
bash experiments/main/figures.sh results_paper
Note: Table 4 is a runtime table. You should expect the runtime per instance to differ based on the CPU and GPU used; however, the runtime increase should be consistent with our findings.
The core part of our work is the implementation of a completion engine that incrementally parses type-safe TypeScript programs. The completion engine can then be used to constrain sampling from an LLM to only generate type-safe programs.
This project is organized as a Python package.
The relevant code for the implementation of type-constrained decoding and sampling is located in the typesafe_llm directory.
The experiments are located in the experiments directory.
We further provide a test suite in the tests directory.
The usage of the latter two is described above.
In the following sections we describe the structure of the typesafe_llm package.
The sampling algorithm presented in Section 2.1 of the paper is located in typesafe_llm/sampling.py.
It uses the transformers library to obtain predictions from a language model and sample from them, and, if constraining is enabled, runs a parser in parallel to reject invalid programs (sample_constrained).
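Conceptually, the constrained sampling loop can be sketched as follows. This is a simplified illustration, not the actual implementation in sampling.py: the toy propose_tokens function stands in for the language model, and completability is checked via the documented parse_ts_program interface.
from typesafe_llm.parser.parser_ts import parse_ts_program

def is_completable(prefix: str) -> bool:
    # the prefix is kept only if at least one parser state survives,
    # i.e. it can still be completed to a type-safe program
    return len(list(parse_ts_program(prefix))) > 0

def propose_tokens(prefix: str) -> list[str]:
    # toy stand-in for the language model: candidate next tokens, best first
    return [")", "1", ";"]

def sample_constrained_sketch(prompt: str, max_steps: int = 3) -> str:
    program = prompt
    for _ in range(max_steps):
        for token in propose_tokens(program):
            # skip candidates (e.g. the unbalanced ")") that make the prefix uncompletable
            if is_completable(program + token):
                program += token
                break
        else:
            break  # no candidate keeps the program completable
    return program

print(sample_constrained_sketch("let x:number = "))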
The prefix automaton is defined in typesafe_llm/automata/parser_base.py.
The automaton is implicitly defined by specifying the transition function and the acceptance status of each state, subclassing from IncrementalParserState.
A state indicates that it is an accepting state by setting the field accept to True.
The transition function is invoked via the method parse_char and returns a list of new states that can be reached by parsing the given character.
The file further contains the definitions of concatenation, union, Kleene plus, and terminal automata.
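To illustrate the interface described above without depending on the actual base class, here is a self-contained toy prefix automaton in the same spirit (the class below is our own illustration, not IncrementalParserState itself): states expose an accept flag and a parse_char method returning the list of successor states.
from dataclasses import dataclass

@dataclass
class ToyKeywordState:
    """Toy prefix automaton accepting exactly the keyword given at construction."""
    keyword: str
    pos: int = 0

    @property
    def accept(self) -> bool:
        # accepting once the whole keyword has been consumed
        return self.pos == len(self.keyword)

    def parse_char(self, char: str) -> list["ToyKeywordState"]:
        # return all states reachable by consuming `char`; an empty list
        # means the prefix can no longer be completed
        if self.pos < len(self.keyword) and self.keyword[self.pos] == char:
            return [ToyKeywordState(self.keyword, self.pos + 1)]
        return []

# drive the automaton character by character, keeping all surviving states
states = [ToyKeywordState("let")]
for char in "le":
    states = [nxt for state in states for nxt in state.parse_char(char)]
print(states, any(s.accept for s in states))  # one surviving state, not yet accepting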
The automaton for identifiers (ExistingIdentifierParserState) is the first automaton defined in typesafe_llm/automata/parser_ts.py.
The following automata parse literals (LiteralParserState and its subclasses), including more advanced literals such as regular expressions and template strings.
The automaton for types is defined separately in typesafe_llm/automata/parser_ts_types.py.
The expression automaton is defined in typesafe_llm/automata/parser_ts.py in the class ExpressionParserState.
It implements the extension logic and the pruning of invalid transitions due to operator precedence and type constraints.
The derivability algorithm is implemented for each state individually in the method derivable. It determines the directly derivable types and calls the reachability algorithm on them.
The type reachability algorithm is implemented in typesafe_llm/parser/types_ts.py in the method reachable, leveraging _reachable_bfs, a straightforward breadth-first-search translation of the presented reachability algorithm.
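To give an intuition for the reachability check, here is a self-contained toy breadth-first search over a handful of hand-written type edges (e.g., string reaches number via .length, as in the 'let x:number = "he' example above). The type graph below is our own illustration and much simpler than the one used in types_ts.py.
from collections import deque

# toy type graph: an edge A -> B means an expression of type A can be extended
# into an expression of type B (e.g. via a member access or an operator)
TYPE_EDGES = {
    "string": ["number"],    # "hey".length : number
    "number": ["boolean"],   # 1 < 2 : boolean
    "boolean": ["boolean"],
}

def reachable_bfs(source: str, target: str) -> bool:
    """Breadth-first search: can an expression of type `source` be extended to type `target`?"""
    queue, seen = deque([source]), {source}
    while queue:
        current = queue.popleft()
        if current == target:
            return True
        for nxt in TYPE_EDGES.get(current, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False

print(reachable_bfs("string", "boolean"))  # True: string -> number -> boolean
print(reachable_bfs("boolean", "number"))  # False in this toy graph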
The automaton for statements is defined in typesafe_llm/automata/parser_ts.py in the class StatementParserState.
It handles the constraining for valid return types.
The automaton for the entire program is defined in typesafe_llm/automata/parser_ts.py in the class ProgramParserState.
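The return-type constraining can be observed through the public parser interface, analogous to the let-binding examples above (a small illustration; the exact number of surviving states depends on the parser internals):
from typesafe_llm.parser.parser_ts import parse_ts_program

# inside a function declared to return number, the parser tracks the expected return type
ok = parse_ts_program("function one(): number { return 1; }")
print(len(list(ok)) > 0)  # completable: the returned expression matches the return type

# analogous to the let-binding example above: "hey" + ... can never become a number,
# so we expect no surviving states for this prefix
bad = parse_ts_program('function one(): number { return "hey" +')
print(len(list(bad)) > 0)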
No. The problem with traditional compilers is that they only provide feedback on a completed program, whereas effectively guiding the LLM during generation requires feedback on partially generated programs. Therefore, a compiler cannot be reused for type-constrained decoding.
No. These systems are designed to aid humans during development, not to provide reliable incremental parsing. As such, while LSPs and other systems are helpful and may be able to handle some partial programs, they usually do not guarantee being able to handle arbitrary partial programs. For example, LSPs are useful for providing possible members of objects or parameter types for calls, and have been used for this purpose [1,2]; however, they cannot always handle partial syntax trees, may not provide help when they fail to derive an object's type, and cannot predict whether a partial expression can be completed into the required type of the current context. To reliably provide steering and constraints for all partial programs, we had to build a custom incremental parser.
[1] Agrawal et al., "Monitor-Guided Decoding of Code LMs with Static Analysis of Repository Context", NeurIPS 2023 (link)
[2] Gvero et al., "Complete Completion using Types and Weights", PLDI 2013 (link)
No. As far as we know, such a constraining algorithm has to be implemented manually for every language, and we are not aware of any implementations of our method for other languages yet (as of June 2025).