2024年4月14日日曜日

8GBのGPUカード+Windows環境でSDXLのLoRA学習にチャレンジ

最近はSDXLを使った作品作りの試行錯誤をしています。いろいろやっていくとどうしても絵柄を調整したくなる。SD1.5では実はオリジナルのLoRAを作っていました。

私は手元ではメモリ8GBのGeForece RTX2070 Superを使っています。通常のRAMは16GBです。

SDXLのLoRA学習は画像生成時以上にメモリが必要とのことでしたが、調べるとU-net部分だけに絞るなどのテクを使えば可能、とあったのですが、いろいろ調べて試してみたら、fp8_baseというオプションを使うことで、そこそこの速度で、この環境でもテキストエンコーダも含めたLoRAが学習できたのでご紹介したいと思います。

SD 1.5でのLoRAづくりをしたことがある人を前提に記載しています。
また、グラボのドライバとgit,Python 3.10以上がインストールされていることも前提です。この辺りはwebuiなどを使っていればおそらく問題はないと思います。

1. 準備
・仮想メモリの割り当てを増やしておく
    システムのプロパティ-パフォーマンス-詳細設定タブ-仮想メモリで、最大仮想メモリが30GB弱割り当てられるようにしておく。(学習の初期に一瞬かなりのメモリを消費するポイントがあるようです)

2. sd-scriptsのインストール
git,pipなどを使って、インストールしていきます。

適当なフォルダを作り、powershellを起動し、そのフォルダの中で以下のコマンドを起動。スクリプト本体と作業用のpython仮想環境を作成します。
git clone  https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

必要なライブラリを 入れていきます。

pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

pip install --upgrade -r requirements.txt

pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu118

一部のファイルをコピーします。

cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\

cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py

cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

 いくつかフォルダも掘っておきます

mkdir dataset 

mkdir outputs

mkdir logs 


3.データの準備

mkdir dataset  

を実行します。その下に、1024x1024以上のサイズの画像と、キャプションをセットにしたデータを配置しますが、この例ではさらにサブフォルダに

    10_class名

という名前のフォルダを作り、その下にデータを置きます。dataset.tomlを指定する方法もあるのですが、上記フォルダの先頭が示す、その画像に対する繰り返し回数を画像セットごとに指定できるので、私は旧来からのこの方法でデータを配置しています。

また、一つのサブフォルダ内の画像セットの縦横のピクセル数は同一である必要があります。なお、ぴったり1024x1024ある必要は、この例ではありません。

また、キャプションテキストファイルの拡張子は.captionとしています。キャプションを付与するタガーの使い方などはここでは省略します。

4.config.yamlファイルの準備

 以下の内容の、config.yamlファイルを作ります。

command_file: null
commands: null
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: 'NO'
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_name: null
tpu_zone: null
use_cpu: false

5.GPUメモリをできるだけ空ける

できるだけ、実行されているアプリを減らし、GPUメモリを空けます

  1. ブラウザ、常駐タスクを終了する。(タスクトレイアイコンも忘れずに)
  2. ディスプレイサイズを800x600にする
  3. 視覚効果を全部切る。(ウインドウマネージャが使うメモリを減らす)

   特に3点目、PC-プロパティ-システム詳細設定-システムのプロパティ-パフォーマンスオプション-「パフォーマンスを優先する」

です。ここまでやると、かなり学習以外のGPUメモリの消費量を下げれると思います。私はこんな感じです。


目標は、「専用GPUメモリ」だけをLoRAの学習に使わせることです。共有GPUメモリを使いだすととたんに学習に必要な時間が3~10倍になってしまいます。逆に、計算時間を我慢すれば、GPUメモリが少なくても共有GPUメモリの仕組みを使ってもう少し制約が少ないLoRAの学習ができないわけでもないです。

6. 学習の実行

以下のコマンドを実行します。bat,cmdファイルにしておいてもよいと思います。

accelerate launch --config_file=".\config.yaml" sdxl_train_network.py `
    --pretrained_model_name_or_path="学習の対象にしたいチェックポイント.safetensors"  `
    --vae="(VAEのありかのフルパス)sdxl_vae.safetensors" `
    --fp8_base `
    --output_dir=".\outputs"    `
    --output_name=sdxl_test_lora  `
    --save_model_as=safetensors  `
    --prior_loss_weight=1.0  `
    --max_train_steps=10  `
    --learning_rate=1e-4  `
    --optimizer_type="adafactor"  `
    --xformers  `
    --mixed_precision="fp16"  `
    --cache_latents_to_disk `
    --enable_bucket `
    --caption_extension=".caption" `
    --gradient_checkpointing `
    --save_every_n_epochs=1  `
    --network_module=networks.lora `
    --no_half_vae `
    --network_dim=8 `
    --network_alpha=2 `
    --bucket_reso_steps=64 `
    --logging_dir=".\logs" `
    --max_train_epochs=2 `
    --train_batch_size=3 `
    --train_data_dir=".\dataset_tmp" --resolution="1024,1024" --min_bucket_reso=256 --max_bucket_reso=2048 

ここではLoRAのRankのディメンジョン数を8、バッチサイズを3にしています。私の環境ではこれが限界でした。また、--fb8_base オプションが重要でして、これがあることで、必要GPUメモリがぐっと減らせます。学習をU-net部分に限定する --network_train_unet_only を指定する必要がなくなりました。

ちなみにこの設定の時の学習の状況は以下のような感じ。


まあ、ギリギリですかね。

さて、学習方法はこれで何とかつかめたのですが、なかなか良い感じの絵柄が生成できなくて困っています。しばらくパラメータを振ったりしていきたいと思います。

補足:Winodows環境では以下のワーニングが出てきます。tritonはWindowsではサポートされていないそうですが、無視しても大丈夫とのことです。 

A matching Triton is not available, some optimizations will not be enabled. Error caught was: No module named 'triton' 


 

0 件のコメント:

コメントを投稿