๊ด€๋ฆฌ ๋ฉ”๋‰ด

Doby's Lab

ArgumentIndexError: tensorflow์˜ LSTM์„ ์‚ฌ์šฉํ•  ๋•Œ ์ฃผ์˜ ๊นŠ๊ฒŒ ๋ณด์•„์•ผ ํ•  Parameter, return_sequences ๋ณธ๋ฌธ

Code about AI/tensorflow

ArgumentIndexError: tensorflow์˜ LSTM์„ ์‚ฌ์šฉํ•  ๋•Œ ์ฃผ์˜ ๊นŠ๊ฒŒ ๋ณด์•„์•ผ ํ•  Parameter, return_sequences

๋„๋น„(Doby) 2023. 8. 24. 12:59

๐Ÿค” Problem

tensorflow์—์„œ ์ œ๊ณตํ•˜๋Š” LSTM์„ ์‚ฌ์šฉ์„ ํ•˜๋ฉด์„œ ๋งˆ์ฃผํ–ˆ๋˜ ์—๋Ÿฌ ArgumentIndexError์— ๋Œ€ํ•ด ์ •๋ฆฌ๋ฅผ ํ•ด๋ณด๊ณ ์ž ํ•ฉ๋‹ˆ๋‹ค.

์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ๋ฅผ ํ•™์Šตํ•˜๋ ค๋Š” ์‹œ๋„์˜ ๊ณผ์ •์—์„œ ์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ–ˆ์œผ๋ฉฐ, ์—๋Ÿฌ์˜ ์ž์„ธํ•œ ๋‚ด์šฉ์€ Incompatible shapes: [32,4] vs. [32,20,4] ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

 

๊ทธ๋ฆฌ๊ณ , ์—๋Ÿฌ๊ฐ€ ๋‚ฌ์—ˆ๋˜ ๋ชจ๋ธ์˜ ์ฝ”๋“œ๋Š” ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

model = tf.keras.models.Sequential()

model.add(tf.keras.layers.LSTM(units=64, 
                               input_shape=train_x.shape[-2:], 
                               dropout=0.1, 
                               recurrent_dropout=0.1,
                               activation='tanh',
                               return_sequences=True))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LSTM(units=32, 
                               dropout=0.1, 
                               recurrent_dropout=0.1,
                               activation='tanh'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(10))
model.add(tf.keras.layers.Dense(4))

๐Ÿ˜€ Solution

Incompatible shapes: [32,4] vs. [32,20,4] ์ด ๋ฌธ์ œ๊ฐ€ ์ผ์–ด๋‚ฌ๋˜ ์ด์œ ๋Š” ๋ชจ๋ธ์ด ํ•™์Šตํ•˜๋Š” ๊ณผ์ •์—์„œ ๊ธฐ์กด์˜ ์˜๋„ํ•œ target์˜ shape๊ฐ€ ํ›„์ž์— ํ•ด๋‹นํ•˜๊ณ , 2๋ฒˆ์งธ LSTM์—์„œ ์˜๋„ํ•œ ๋ฐ”์™€ ๋‹ค๋ฅด๊ฒŒ output์ด ๋˜์—ˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด์„œ 1๋ฒˆ์งธ LSTM๊ณผ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ recurrent_sequences๋ฅผ True๋กœ ๋ฐ”๊ฟ”์ฃผ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

 

์ด ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ์˜๋ฏธํ•˜๋Š” ๋ฐ”๋Š” LSTM์ด ๊ธฐ๋ฐ˜์ด ๋˜๋Š” RNN์˜ ๊ตฌ์กฐ๋ฅผ ์‚ดํŽด๋ณด๋ฉด ์ดํ•ด๊ฐ€ ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. return_squences๊ฐ€ default๋กœ False์ผ ๊ฒฝ์šฐ, ๋งˆ์ง€๋ง‰์˜ ๋ชจ๋“  input๊ณผ ๋ชจ๋“  state์— ํ•ด๋‹นํ•˜๋Š” ์••์ถ•๊ฐ’๊ณผ ์œ ์‚ฌํ•œ ๋งˆ์ง€๋ง‰ output๋งŒ์„ LSTM Layer์˜ output์œผ๋กœ ๋‹ค์Œ Layer๋กœ ๋ณด๋‚ด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

๊ทธ๋ ‡๊ธฐ ๋•Œ๋ฌธ์— ์˜๋„ํ•œ LSTM์˜ output์ด [32, 20, 4]๊ฐ€ ๋‚˜์˜ค์ง€ ์•Š์•˜๋˜ ์ด์œ ๋Š” ํƒ€์ž„์Šคํ…์— ๋”ฐ๋ฅธ output์„ ๋‚ด๋†“์ง€ ์•Š๊ณ , ๋‹จํŽธ์ ์ธ ํ•˜๋‚˜์˜ output์„ ๋‚ด๋†“์•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

return_squences๊ฐ€ False์ผ ๋•Œ, ์•„๋ž˜ ๊ทธ๋ฆผ์˜ ๋นจ๊ฐ„ ๋ฐ•์Šค์— ํ•ด๋‹นํ•ฉ๋‹ˆ๋‹ค.

์ถœ์ฒ˜: ์œ„ํ‚ค๋…์Šค

๊ทธ๋ž˜์„œ return_sequences๊ฐ€ True์ผ ๊ฒฝ์šฐ์— ํƒ€์ž„ ์Šคํ…์— ๋”ฐ๋ผ output์„ ๋‚ด๋†“์œผ๋ฉฐ, ์•„๋ž˜์˜ ๊ทธ๋ฆผ์—์„œ ๋นจ๊ฐ„ ๋ฐ•์Šค๊ฐ€ LSTM์˜ output์— ํ•ด๋‹น๋ฉ๋‹ˆ๋‹ค.

์‚ฌ์‹ค ์ด๋Ÿฌํ•œ ์—๋Ÿฌ๊ฐ€ ๊ฐ„๋‹จํ•œ ์—๋Ÿฌ๋ผ ๋ณผ ์ˆ˜ ์žˆ์—ˆ์œผ๋‚˜, ์ €๋Š” ํ•ด๊ฒฐ์„ ํ•˜๋ฉด์„œ ๊ธฐ์กด์˜ ๊ณต๋ถ€ ๋ฐฉํ–ฅ์— ๋Œ€ํ•ด ๊ณ ๋ฏผ์„ ํ•˜๋Š” ๊ณ„๊ธฐ๊ฐ€ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ๋˜ํ•œ, ์ €๋ฒˆ XGB Wrapper์— ๊ด€ํ•œ ์—๋Ÿฌ์—์„œ๋„ ๋น„์Šทํ•œ ์ƒ๊ฐ์„ ํ–ˆ์—ˆ์Šต๋‹ˆ๋‹ค.

 

์ง€๊ธˆ๊นŒ์ง€ ์ด๋ก  ์ชฝ์œผ๋กœ๋งŒ ๋„ˆ๋ฌด ์น˜์šฐ์น˜๊ฒŒ ๊ณต๋ถ€๋ฅผ ํ•ด์˜จ ๊ฒŒ ์•„๋‹๊นŒ๋ผ๋Š” ์ƒ๊ฐ์„ ํ–ˆ์Šต๋‹ˆ๋‹ค. ํ˜„๋Œ€ ๋ฌผ๋ฆฌํ•™์€ ์–ด๋–ค์ง€ ๋ชจ๋ฅด๊ฒ ์œผ๋‚˜ ์˜›๋‚  ๋ฌผ๋ฆฌํ•™์€ ํฌ๊ฒŒ ์ด๋ก  ๋ฌผ๋ฆฌํ•™๊ณผ ์‹คํ—˜ ๋ฌผ๋ฆฌํ•™์œผ๋กœ ๋‚˜๋‰˜์—ˆ๋‹ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ์ €๋Š” ์ด๋ก  ์ชฝ์œผ๋กœ๋งŒ ํŽธํ–ฅ์ ์œผ๋กœ ๊ณต๋ถ€๋ฅผ ํ•ด์™”๊ธฐ์— ์‹คํ—˜์— ํ•ด๋‹นํ•˜๋Š” ์ฝ”๋“œ๋ฅผ ๋‹ค๋ฃจ๋Š” ๋ถ€๋ถ„์€ ๋ถ€์กฑํ–ˆ๊ธฐ์— ์ด๋Ÿฐ ์—๋Ÿฌ๋“ค์„ ๊ฒช์€ ๊ฑฐ๋ผ ํ™•์‹ ํ•ฉ๋‹ˆ๋‹ค.

 

๊ธฐ์กด์˜ ์ด ๋ธ”๋กœ๊ทธ์˜ ํฌ์ŠคํŒ…์„ ๋ณด๋”๋ผ๋„ AI์— ๊ด€๋ จ๋œ ํฌ์ŠคํŒ…์—์„œ๋Š” ์ด๋ก ์— ๋Œ€ํ•ด ์ด์•ผ๊ธฐ๊ฐ€ ๋งŽ์•˜์ง€๋งŒ, ์ฝ”๋“œ๋กœ ์–ด๋–ป๊ฒŒ ๊ตฌํ˜„์ด ๋  ์ˆ˜ ์žˆ๋Š”์ง€ ํ˜น์€ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค๋ฉด, ์–ด๋–ป๊ฒŒ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š”์ง€์— ๋Œ€ํ•œ ๋ถ€๋ถ„์ด ๋ถ€์กฑํ–ˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์—๋Ÿฌ๋“ค์„ ๊ณ„๊ธฐ๋กœ ์ƒˆ๋กœ์šด ์ด๋ก ๋“ค์„ ๊ณต๋ถ€ํ•˜๊ฒŒ ๋  ๋•Œ๋Š” ์ด๋ก ๊ณผ ์‹คํ—˜(์ฝ”๋“œ) ๋‘˜ ๋‹ค ์ฑ™๊ฒจ๊ฐ€๋Š” ๊ณต๋ถ€๋ฅผ ํ•ด์•ผ๊ฒ ๋‹ค๋Š” ๊ด€์ ์ด ๋งŒ๋“ค์–ด์ง„ ๋“ฏํ•ฉ๋‹ˆ๋‹ค.

728x90