StableHLO মধ্যে গতিশীলতা

ডায়নামিজমের বর্তমান অবস্থাটি আরও আনুষ্ঠানিকভাবে ডায়নামিজম RFC- তে বানান করা হয়েছে, এই পৃষ্ঠাটি RFC-এর একটি উচ্চ স্তরের ওভারভিউ প্রদান করবে এবং ডায়নামিক প্রোগ্রামগুলির সাথে ইন্টারঅ্যাক্ট করার জন্য গুরুত্বপূর্ণ API এবং টুলিং নিয়ে আলোচনা করবে।

ডায়নামিজম পরিভাষা এবং সমর্থন ওভারভিউ

প্রথমত, এই ডকে প্রদর্শিত কিছু শর্তাদি কভার করতে, সেইসাথে StableHLO-তে তাদের সমর্থনের একটি সংক্ষিপ্ত ভূমিকা:

গতিশীল মাত্রা

ডাইনামিক ডাইমেনশন বলতে এমন কোনো ডাইমেনশনকে বোঝায় যার ডাইমেনশন সাইজ অজানা। StableHLO এ আমরা ব্যবহার করে গতিশীল মাত্রা উপস্থাপন করি ? , যেমন tensor<16x?xf32>

আবদ্ধ গতিশীলতা

আবদ্ধ গতিশীলতা একটি গতিশীল মাত্রা বোঝায় যার মান একটি পরিচিত উপরের সীমা আছে। সাধারণত এটি কার্যকর করার সময় টেনসর প্যাডিংয়ের জন্য দরকারী। StableHLO-তে আমরা টেনসর এনকোডিং হিসাবে #stablehlo.bounds ব্যবহার করে আবদ্ধ গতিশীলতার প্রতিনিধিত্ব করি, অর্থাৎ একটি র্যাঙ্ক-2 টেনসর যার একটি গতিশীল মাত্রা 16 এ আবদ্ধ এবং অন্যটি সীমাহীন tensor<?x?xf32, #stablehlo.bounds<16, ?>>

StableHLO সীমাবদ্ধ গতিশীলতার প্রতিনিধিত্ব করতে সক্ষম, তবে সীমিত কাঠামো সমর্থন রয়েছে, যা TensorFlow থেকে উদ্ভূত হয়েছে এবং PyTorch/XLA-তে কিছু সমর্থন রয়েছে।

সীমাহীন গতিশীলতা

সীমাহীন গতিশীলতা নামটি বোঝায় একটি গতিশীল ডাইমেনশনকে বোঝায় যার আকারের কোন পরিচিত আবদ্ধ নেই। JAX, PyTorch/XLA, এবং TF সমর্থন সহ StableHLO-তে এই ধরনের গতিশীলতা খুবই সাধারণ, যা প্রায়শই গতিশীল ব্যাচের আকার বা সিকোয়েন্স দৈর্ঘ্য সহ মডেল রপ্তানির জন্য ব্যবহৃত হয়।

StableHLO-তে আমরা এই ধরনের গতিশীলতার জন্য সীমানা এনকোডিংকে এলিড করি, যেমন tensor<?x?xf32>

আকৃতি পলিমরফিজম

শেপ পলিমরফিজম হল একটি শব্দ যা আমরা JAX থেকে উত্তরাধিকার সূত্রে পেয়েছি

পলিমরফিজমকে আকৃতি দেওয়ার জন্য দুটি মূল প্রভাব রয়েছে:

  1. প্রোগ্রামের সমস্ত গতিশীলতা তার ইনপুট আর্গুমেন্টে ফিরে আসে।
  2. সমস্ত গতিশীলতা শুধুমাত্র টেনসর আকারের সাথে সম্পর্কিত, অর্থাৎ ডেটা-নির্ভর নয়।

এই দুটি নিয়মের সাহায্যে, একবার একটি প্রোগ্রামের স্ট্যাটিক আকৃতি জানা হয়ে গেলে, আমরা একটি গতিশীল প্রোগ্রাম নিতে এবং সংকলনের জন্য এটিকে একটি স্ট্যাটিক প্রোগ্রামে সম্পূর্ণরূপে পরিমার্জন করতে সক্ষম হই (দেখুন "ডাইনামিক প্রোগ্রামগুলিকে পরিশোধনের জন্য কম্পাইলার পাস" )।

সাধারণত শেপ পলিমরফিজম সীমাহীন গতিশীলতা ব্যবহার করে, যদি পরিচিত আর্গুমেন্ট আকারগুলি একটি সম্পূর্ণ স্ট্যাটিক প্রোগ্রামের দিকে নিয়ে যেতে পারে, তাহলে মানগুলিকে কীভাবে আবদ্ধ করা যায় তা অনুমান করার দরকার নেই।

ডেটা-নির্ভর গতিশীলতা

ডেটা-নির্ভর গতিশীলতা বলতে ডায়নামিক ডাইমেনশনের আকার বোঝায় যা একটি টেনসরের ভিতরের ডেটার সাথে সম্পর্কিত। ক্যানোনিকাল উদাহরণ হল একটি nonzeros ফাংশন যা একটি টেনসর মান 0 এর সমস্ত উপাদানের সূচক প্রদান করে। ডেটা মূল্যায়ন না করে আকারটি জানা যাবে না, তবে এটি প্রায়শই আবদ্ধ গতিশীলতা ব্যবহার করে সংকলন করা যেতে পারে, সম্ভাব্য আউটপুট টেনসর আকারের উপর অতিরিক্ত মেমরি ব্যয় করে।

অনেক ডেটা-নির্ভর ডায়নামিক অপ্স বাউন্ডেড ডাইনামিজম ব্যবহার করে মডেল করা যেতে পারে, যেখানে একটি টেনসর সাইজের উপরি বাউন্ড নির্দিষ্ট করা হয় এবং হার্ডওয়্যার সাধারণত টেনসর প্যাডিংয়ের মাধ্যমে এটি বাস্তবায়ন করবে। বর্তমানে PyTorch/XLA এবং TensorFlow-এ ডেটা-নির্ভর গতিশীলতার জন্য কিছু সমর্থন রয়েছে, কিন্তু JAX বর্তমানে এমন ক্রিয়াকলাপগুলি খুঁজে পায় না যা ডেটা নির্ভরশীল গতিশীলতার দিকে পরিচালিত করে।

গতিশীল মাত্রা সহ প্রোগ্রাম রপ্তানি

ডায়নামিক ব্যাচের আকার বা ক্রম দৈর্ঘ্য সহ প্রোগ্রামগুলি কীভাবে রপ্তানি করা যায় সে সম্পর্কে তথ্যের জন্য আমাদের StableHLO টিউটোরিয়ালগুলি দেখুন:

ডায়নামিক প্রোগ্রাম পরিমার্জন করার জন্য কম্পাইলার পাস

গতিশীলতা পাস পাইপলাইন সরান

আকার পরিমার্জন করার জন্য কয়েকটি দরকারী পাস রয়েছে, সুবিধামত সেগুলি সবগুলি একটি পাস পাইপলাইনে বান্ডিল করে createStablehloRemoveDynamismPipeline :

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

পরিমার্জিত গতিশীলতার জন্য পৃথক পাস

স্বতন্ত্রভাবে, যে পাসগুলি আকৃতির পরিমার্জনের জন্য উপযোগী হতে থাকে তা হল:

  • কংক্রিট টেনসর প্রকারের সাথে ইনপুট আর্গুমেন্ট প্রতিস্থাপন করতে stablehlo-refine-arguments
  • stablehlo-refine-shapes সমগ্র প্রোগ্রাম জুড়ে নতুন ইনপুট আর্গুমেন্ট আকৃতির তথ্য প্রচার করতে।
  • stablehlo-canonicalize-dynamism ডায়নামিক অপগুলিকে তাদের স্ট্যাটিক ভেরিয়েন্ট দিয়ে প্রতিস্থাপন করতে।

আপ-টু-ডেট তথ্য এবং উদাহরণের জন্য লিঙ্কড ডকুমেন্টেশন দেখুন।

উদাহরণ: কিভাবে গতিশীলতা দরকারী, এবং আমি কিভাবে এটি ব্যবহার করতে পারি?

ডায়নামিজমের প্রচুর ব্যবহার রয়েছে, এখানে আমরা মূলত শেপ পলিমরফিজমের সাধারণ ব্যবহারের ক্ষেত্রে ফোকাস করব - একটি নমনীয় রপ্তানিকৃত মডেল উপস্থাপনা তৈরি করা, যা সাধারণত গতিশীল ব্যাচের আকার বা ক্রম দৈর্ঘ্য উপস্থাপন করতে ব্যবহৃত হয়।

স্ট্যাটিক add_one মডেল

আমরা এটি প্রদর্শন করতে নিম্নলিখিত সহজ add_one মডেল ব্যবহার করব:

def add_one(x):
  return x + 1

যখন একটি tensor<4xf32> ব্যবহার করে চিহ্নিত করা হয় তখন আমরা নিম্নলিখিত StableHLO প্রোগ্রামটি পাব:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

এই মডেলটি শুধুমাত্র ইনপুট আর্গুমেন্টের জন্য কাজ করবে যার একটি tensor<4xf32> আকৃতি আছে। যদি আমরা কখনও আমাদের ব্যাচের আকার বা ক্রম দৈর্ঘ্য পরিবর্তন করি, তাহলে আমাদের সোর্স কোডটি পুনরায় ট্রেস করতে হবে এবং StableHLO-তে পুনরায় কমিয়ে আনতে হবে, এবং এমন কোন গ্যারান্টি নেই যে আমাদের এখনও সোর্স কোডে অ্যাক্সেস আছে!

ডাইনামিক add_one মডেল

এখানেই আকৃতির বহুরূপী গতিশীলতা খেলায় আসে। পরিবর্তে JAX এবং PyTorch/XLA গতিশীলভাবে বৈধ IR সহ add_one মডেল নির্গত করতে পারে যা নিম্নরূপ গতিশীল ইনপুট আকারের সাথে মেলে ধ্রুবক সম্প্রচার করবে:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

এই মডেল উপস্থাপনা অনেক বেশি নমনীয়, এবং ব্যাচের আকার বা ক্রম দৈর্ঘ্যের মতো মানগুলির বিলম্বিত স্পেসিফিকেশনের অনুমতি দেয়। এই মডেলটি গতিশীল আকৃতি সমর্থন সহ প্ল্যাটফর্মে স্থাপন করা যেতে পারে (যেমন AI এজ ), অথবা এই ডকুমেন্টেশনে উল্লিখিত গতিশীলতা পাসগুলি ব্যবহার করে এটি পরিমার্জিত করা যেতে পারে।

গতিশীল মডেল পরিমার্জন

উদাহরণস্বরূপ নিম্নলিখিত পাস অর্ডারিং এই প্রোগ্রামটিকে সম্পূর্ণরূপে পরিমার্জন করতে পারে:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

ক্রমবর্ধমানভাবে, এইভাবে প্রোগ্রামটি রূপান্তরিত হয়:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}