
BitShift

Bases: BaseModule

Multiply-by-power-of-two layer that emulates compiler-side bit shifts.

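Multiplies the input by 2 ** -nb_shift (direction "right") or 2 ** nb_shift (direction "left"), casts the result to int to reproduce the compiler's truncation, and uses a straight-through estimator: in the backward pass the cast is treated as the identity, so gradients flow through the power-of-two scaling only.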

Parameters:

Name       Type   Description                                                              Default
nb_shift   int    Shift amount (positive integer).                                         required
direction  str    "right" (divide by 2 ** nb_shift) or "left" (multiply by 2 ** nb_shift).  'right'
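To show how `nb_shift` and `direction` map onto integer bit shifts, here is a small sketch. The helper `emulated_shift` is hypothetical and mirrors the documented parameters. Its rejection of non-positive `nb_shift` is an assumption based on the "positive integer" description, not confirmed library behavior. For non-negative integers, truncating the scaled value gives the same result as Python's native `>>` and `<<` operators.

```python
def emulated_shift(x: int, nb_shift: int, direction: str = "right") -> int:
    # Hypothetical helper mirroring the documented parameters; not the library's API.
    # Assumption: nb_shift must be a positive integer, as the parameter table states.
    if not isinstance(nb_shift, int) or nb_shift <= 0:
        raise ValueError("nb_shift must be a positive integer")
    if direction == "right":
        return int(x * 2.0 ** -nb_shift)  # divide by 2 ** nb_shift, then truncate
    if direction == "left":
        return int(x * 2.0 ** nb_shift)   # multiply by 2 ** nb_shift
    raise ValueError("direction must be 'right' or 'left'")


print(emulated_shift(100, 3))        # 100 / 8 = 12.5 -> 12, same as 100 >> 3
print(emulated_shift(5, 3, "left"))  # 5 * 8 -> 40, same as 5 << 3

# For non-negative integers the emulation agrees with native bit shifts.
for x in range(1000):
    for n in range(1, 8):
        assert emulated_shift(x, n) == x >> n
        assert emulated_shift(x, n, "left") == x << n
```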