树扩散
结构
主要代码位于 td/
目录。环境实现在 td/environments
,突变和树路径搜索在 td/samplers/mutator.py
,通用语法实现在 td/grammar.py
。
所有模型权重位于 assets/
目录。
设置
下载模型权重
请从 此链接 下载所有模型权重,并将它们放在此仓库根目录的新建 assets/
文件夹中。
依赖
使用 Python 3.11。
安装依赖:
pip install -r requirements.txt
使用
确保 PYTHONPATH
正确设置:
在 Linux/Mac 上:
export PYTHONPATH=.
在 Windows 上:
set PYTHONPATH=.
首先测试设置:
python scripts/test_setup.py
进行评估:
python scripts/eval_td_search.py --checkpoint_name assets/td_csg2da.pt --ar_checkpoint_name assets/ar_csg2da.pt --problem_filename assets/csg2da_test_set.pkl
评估会打印结果和解决问题所需的步数到新的 evals/
目录。
进行训练:
python scripts/train.py --env csg2da --batch_size 32 --num_workers 16 --max_primitives 8 --n_layers 8 --d_model 256 --num_heads 16 --test_every 1000 --forward_mode path --nowandb