Rubyのparsletを使って単位と物理定数の両方に対応した計算機

Googleで"10km/5sec/c"とかやると単位も物理定数も考えて計算してくれるのは有名な話だけど、当然の事ながらオフラインだとできない。不便なので作ってみようと思った。単位なのか物理定数なのかを判断するようにするのは面倒だったから単位は角括弧で囲むようにした。

parsletというのを使って構文解析して単に計算する。
動作はコマンドラインのみでこんな感じ。

> 180[km]/20[sec]
9000.0 [m s-1]
> 250[kpc]/1000[km/s]=?Myr
244.4532762791914 [Myr]
> G*5.972e24[kg]/6371[km]**2
9.819296622997971 [m s-2]


一応ソースコードを貼っておく。ここを大いに参考にした。クラスの使い方がヘタクソでMKSAの香りがAmountの中に漏れ出てきているし、どこでも括弧を使えるようにするために、左再帰への対症療法として"1*"とかが入っていたりして、かっこわるいことこの上ない。だれかいい感じに書き直してくれ。できれば関数にも対応してほしいな。

割と切実にいい感じにしてほしいし、ここにべた書きするよりGitHubに上げるべきか。

#!/usr/bin/env ruby

require 'parslet'

# A physical quantity: a numeric value expressed in SI base units together
# with its dimension.
#
# +unit+ is a flat array of alternating base-unit names and exponents in the
# order produced by Unit_MKSA (m, kg, s, A, K), for example
# ['m', 1, 'kg', 0, 's', -1, 'A', 0, 'K', 0] for a velocity.
#
# Binary operators accept another Amount or a plain Numeric; a plain Numeric
# is treated as a dimensionless Amount.
class Amount < Numeric
  attr_accessor :val, :unit

  # val  - numeric magnitude expressed in +unit+.
  # unit - flat [name, exponent, ...] array of unit names known to Unit_MKSA.
  #
  # The value is multiplied by the conversion factor reported by Unit_MKSA and
  # the unit is replaced by the normalised array Unit_MKSA computes.
  def initialize(val, unit)
    u = Unit_MKSA.new(unit)
    @unit = u.unit
    @val = val * u.const
  end

  # Sum of two amounts. Raises RuntimeError when the dimensions differ.
  def +(other)
    other = operand(other)
    require_same_dimension(other)
    with_unit(val + other.val, unit)
  end

  # Difference of two amounts. Raises RuntimeError when the dimensions differ.
  def -(other)
    other = operand(other)
    require_same_dimension(other)
    with_unit(val - other.val, unit)
  end

  # Negated amount with the same dimension.
  def -@
    with_unit(-val, unit)
  end

  # Product; the exponents of both operands are added.
  def *(other)
    other = operand(other)
    with_unit(val * other.val, combine_units(other.unit, 1))
  end

  # Quotient; the exponents of +other+ are subtracted.
  def /(other)
    other = operand(other)
    with_unit(val / other.val, combine_units(other.unit, -1))
  end

  # Power. The exponent must be a dimensionless Amount or an Integer;
  # anything else raises RuntimeError.
  def **(other)
    if other.is_a?(Amount) && dimensionless?(other.unit)
      with_unit(val**other.val, power_unit_array(unit, other.val))
    elsif other.is_a?(Integer)
      with_unit(val**other, power_unit_array(unit, other))
    else
      raise('index should be an Integer')
    end
  end

  # Three-way comparison of the values.
  # Raises RuntimeError when the dimensions differ.
  def <=>(other)
    other = operand(other)
    require_same_dimension(other)
    if val < other.val
      -1
    elsif val == other.val
      0
    else
      1
    end
  end

  # Remainder. Raises RuntimeError when the dimensions differ.
  def %(other)
    other = operand(other)
    require_same_dimension(other)
    with_unit(val % other.val, unit)
  end

  # Absolute value; returns self when the value is already non-negative.
  def abs
    val < 0 ? with_unit(-val, unit) : self
  end

  def to_f
    val.to_f
  end

  def to_i
    val.to_i
  end

  # Returns a copy of the flat unit array +ary+ with every exponent
  # multiplied by +pow+.
  def power_unit_array(ary, pow)
    ary.each_slice(2).flat_map { |name, exponent| [name, exponent * pow] }
  end

  # True when every exponent in the flat unit array +ary+ is zero
  # (an empty array therefore counts as dimensionless).
  def dimensionless?(ary)
    ary.each_slice(2).all? { |_name, exponent| exponent == 0 }
  end

  private

  # Wraps a plain Numeric in a dimensionless Amount; anything else is
  # returned untouched.
  def operand(other)
    return other if other.is_a?(Amount) || !other.is_a?(Numeric)

    Amount.new(other, [])
  end

  # Raises unless +other+ has exactly the same unit array as self.
  def require_same_dimension(other)
    raise('incompatible dimensions') unless unit == other.unit
  end

  # Builds an Amount holding +value+ with a private copy of +unit_ary+.
  # The value is passed with an empty unit so no conversion is applied.
  def with_unit(value, unit_ary)
    result = Amount.new(value, [])
    result.unit = unit_ary.dup
    result
  end

  # Merges own exponents with those of +other_unit+; +sign+ is 1 for
  # multiplication and -1 for division.
  def combine_units(other_unit, sign)
    exponents = {'m' => 0, 'kg' => 0, 's' => 0, 'A' => 0, 'K' => 0}
    unit.each_slice(2) { |name, exponent| exponents[name] += exponent }
    other_unit.each_slice(2) { |name, exponent| exponents[name] += sign * exponent }
    exponents.flat_map { |name, exponent| [name, exponent] }
  end
end

# Resolves a flat [name, exponent, ...] array of unit names into
#   * +const+ - the factor that converts a value in those units to SI base
#               units, and
#   * +unit+  - the equivalent dimension as a flat array of base-unit names
#               and exponents in the fixed order m, kg, s, A, K.
class Unit_MKSA
  # name => [factor to SI base units, flat base-unit array]
  @@mksa = {
    'mm' => [1e-3,['m',1]],
    'cm' => [1e-2,['m',1]],
    'm' => [1,['m',1]],
    'km' => [1000,['m',1]],
    'AU' => [149597870700, ['m',1]],
    'au' => [149597870700, ['m',1]],
    'pc' => [3.08567758e16,['m',1]],
    'kpc' => [3.08567758e19,['m',1]],
    'Mpc' => [3.08567758e22,['m',1]],
    'lyr' => [9.4605284e15,['m',1]],
    'g' => [1e-3,['kg',1]],
    'kg' => [1,['kg',1]],
    'M_sun' => [1.9891e30, ['kg',1]],
    'Msun' => [1.9891e30, ['kg',1]],
    's' => [1,['s',1]],
    'sec' => [1,['s',1]],
    'min' => [60,['s',1]],
    'hour' => [60*60,['s',1]],
    'day' => [86400,['s',1]],
    'yr' => [31556926,['s',1]],
    'Myr' => [3.1556926e13,['s',1]],
    'Gyr' => [3.1556926e16,['s',1]],
    'A' => [1,['A',1]],
    'J' => [1,['kg',1,'m',2,'s',-2]],
    'W' => [1,['kg',1,'m',2,'s',-3]],
    'N' => [1,['kg',1,'m',1,'s',-2]],
    'eV' => [1.60217657e-19,['kg',1,'m',2,'s',-2]],
    'keV' => [1.60217657e-16,['kg',1,'m',2,'s',-2]],
    'MeV' => [1.60217657e-13,['kg',1,'m',2,'s',-2]],
    'GeV' => [1.60217657e-10,['kg',1,'m',2,'s',-2]],
    'TeV' => [1.60217657e-7,['kg',1,'m',2,'s',-2]],
    'PeV' => [1.60217657e-4,['kg',1,'m',2,'s',-2]],
    'C' => [1, ['A',1,'s',1]],
    'K' => [1, ['K',1]],
    'V' => [1,['kg',1,'m',2,'s',-3,'A',-1]],
  }

  attr_accessor :unit, :const

  # unit_ary - flat array of alternating unit names and exponents,
  #            e.g. ['km', 1, 's', -1].
  #
  # An unknown unit name prints an error banner to stdout and stops
  # processing; the units resolved before it are kept. Any other failure
  # (for instance a non-numeric exponent) is raised to the caller instead of
  # being reported as an unknown unit.
  def initialize(unit_ary)
    const = 1.0
    ary = []
    unit_ary.each_slice(2) do |u, i|
      entry = @@mksa[u]
      unless entry
        puts "\e[1;31m################ ERROR ################"
        puts "unit named \'#{u}\' is not defined"
        break
      end
      factor, base = entry
      const *= factor**i
      ary += power_unit_array(base, i)
    end
    @const = const
    @unit = simplify_unit_array(ary)
  end

  # Returns a copy of the flat unit array +ary+ with every exponent
  # multiplied by +pow+.
  def power_unit_array(ary, pow)
    ary.each_slice(2).flat_map { |name, exponent| [name, exponent * pow] }
  end

  # Sums the exponents per base unit and returns them as a flat array in the
  # fixed order m, kg, s, A, K (zero exponents included).
  def simplify_unit_array(ary)
    totals = {'m' => 0, 'kg' => 0, 's' => 0, 'A' => 0, 'K' => 0}
    ary.each_slice(2) { |name, exponent| totals[name] += exponent }
    totals.flat_map { |name, exponent| [name, exponent] }
  end

end

# A named physical or mathematical constant, looked up by name and built as
# an Amount.
class Constant < Amount
  # name => [value, flat unit array in the unit names Unit_MKSA understands]
  @@const = {
    'c' => [299792458, ['m',1,'s',-1]],
    'c0' => [299792458, ['m',1,'s',-1]],
    'c_0' => [299792458, ['m',1,'s',-1]],
    'G' => [6.67384e-11, ['N',1,'m',2,'kg',-2]],
    'h' => [6.62606957e-34, ['J',1,'s',1]],
    'q' => [1.60217657e-19, ['C',1]],
    'm_e' => [9.10938291e-31, ['kg',1]],
    'm_p' => [1.672621777e-27, ['kg',1]],
    'k' => [1.3806488e-23, ['J',1,'K',-1]],
    'k_b' => [1.3806488e-23, ['J',1,'K',-1]],
    'e' => [2.71828182846, []],
    'pi' => [3.14159265359, []],
    'M_sun' => [1.9891e30, ['kg',1]],
    'Msun' => [1.9891e30, ['kg',1]],
    'AU' => [149597870700, ['m',1]],
    'au' => [149597870700, ['m',1]],
  }

  # str - name of the constant.
  #
  # An unknown name prints an error banner to stdout and yields a
  # dimensionless zero. The lookup is checked explicitly, so an error raised
  # while building a known constant is not misreported as "not defined".
  def initialize(str)
    value, unit = @@const[str]
    if value
      super(value, unit)
    else
      puts "\e[1;31m################ ERROR ################"
      puts "constant named \'#{str}\' is not defined"
      super(0, [])
    end
  end
end

# Splits an input line into the expression to evaluate (:lhs) and an optional
# conversion target introduced by "=?".
#
# The target is captured as :rhs_nb when it contains no brackets or
# parentheses (e.g. "Myr"), otherwise as :rhs (e.g. "[km/s]").
class Query < Parslet::Parser
  rule(:space) { str(' ').repeat(1) }
  rule(:space?) { str(' ').repeat(0) }
  rule(:op) { str('=?') >> space?}

  ################ form ################
  rule(:form) { (match("[a-zA-Z0-9]") | match("[-*^/+_.\s]") |
                 str(']') | str('[') | str(')') | str('(')).repeat }

  rule(:form_nonbracket) { (match("[a-zA-Z0-9]") | match("[-*^/+_.\s]")).repeat }

  rule(:query) {
    # form_nonbracket is a zero-or-more repeat and therefore always succeeds.
    # Without the end-of-input check it would win the ordered choice even
    # when brackets follow, leaving them unconsumed so the whole parse fails
    # and the :rhs alternative is never tried.
    form.as(:lhs) >> op.as(:op) >>
      ((form_nonbracket.as(:rhs_nb) >> any.absent?) | form.as(:rhs)) |
      form.as(:lhs)
  }
  root(:query)
end

# Grammar for arithmetic expressions over numbers, numbers with a bracketed
# unit (e.g. "10[km/s]") and named constants.
#
# The root rule yields a flat list of factors, each factor after the first
# tagged with the :operator that precedes it; operator precedence is not
# resolved by this grammar.
class Expression < Parslet::Parser
  rule(:space) { str(' ').repeat(1) }
  rule(:space?) { str(' ').repeat(0) }
  rule(:lparen)     { str('(') >> space? }
  rule(:rparen)     { str(')') >> space? }

  ################ number ################
  # Optionally signed run of digits.
  rule(:integer) {
    (str('+') | str('-')).maybe >>
    match("[0-9]").repeat(1)
  }

  # Decimal literal with optional fraction (:dec) and optional exponent (:pow).
  rule(:float) {
    (
     (
      integer >>
      (str('.') >> match("[0-9]").repeat(1)).maybe
      ).as(:dec) >>
     (match("[eE]") >> integer.as(:pow)).maybe
     ).as(:float)
  }

  rule(:number) {
    float >> space?
  }

  ################ constant ################
  rule(:constant) {
    match("[0-9a-zA-Z_]").repeat(1).as(:const) >> space?
  }

  ################ unit ################
  rule(:lbracket)     { str('[') >> space? }
  rule(:rbracket)     { str(']') >> space? }
  def bracketed(atom)
    lbracket >> atom >> rbracket
  end

  # One unit name with an optional integer power: "m", "s-1", "km**2", "m^3".
  # Underscore is accepted so that names such as "M_sun" can be written.
  rule(:unit_single) {
    match["a-zA-Z_"].repeat(1).as(:u_atom) >>
    (str('**') | str('^')).maybe >>
    integer.as(:u_pow).maybe
  }

  rule(:unit_expression) {
    unit_operation | unit_single
  }

  # Right-recursive chain of unit names joined by "/", "*" or a space.
  rule(:unit_operation) {
    unit_single.as(:u_1) >>
    unit_operator.as(:u_op) >>
    unit_expression.as(:u_2) >> space?
  }

  rule(:unit_operator) {
    (space? >> match["/*"] | match(" ")) >> space?
  }

  rule(:unit) {
    bracketed(unit_expression.as(:unit))
  }

  ################ amount ################
  # Number with unit, bare number, or constant name - tried in that order.
  rule(:amount) {
    (number >> unit | number | constant).as(:amount)
  }

  ################ expression ################
  def parened(atom)
    lparen >> atom >> rparen
  end

  rule(:factor) { amount | parened(expression.as(:expression)) }
  # '**' is listed before '*' so the two-character operator is matched first.
  rule(:expression) { factor >>
    ((str('**') | str('/') | str('+') | str('-') | str('*') | str('^')).as(:operator) >>
     space? >> factor).repeat}

  root(:expression)
end

# Turns the raw parse tree produced by Expression into Ruby values:
# number literals and constants become Amount/Constant objects, and unit
# expressions become flat [name, exponent, ...] arrays.
class Transform_Expression < Parslet::Transform
  # Number literal without / with a decimal exponent.
  rule(:dec => simple(:mantissa)) { mantissa.to_f }
  rule(:dec => simple(:mantissa), :pow => simple(:exponent)) {
    "#{mantissa}e#{exponent}".to_f
  }

  # Single unit name; the exponent is 1 when none was written.
  rule(:u_atom => simple(:label)) { [label.to_s, 1] }
  rule(:u_atom => simple(:label), :u_pow => simple(:power)) {
    [label.to_s, power.to_i]
  }

  # Two unit parts joined by an operator. For a division only the exponent of
  # the unit directly after the operator (index 3 of the joined array) is
  # negated.
  rule(:u_1 => sequence(:head), :u_op => simple(:joiner), :u_2 => sequence(:tail)) {
    joined = head + tail
    joined[3] = -joined[3] if joiner.to_s.include?("/")
    joined
  }

  # Bare number, number with unit, named constant.
  rule(:float => simple(:number)) { Amount.new(number, []) }
  rule(:float => simple(:number), :unit => sequence(:units)) {
    Amount.new(number, units)
  }
  rule(:const => simple(:label)) { Constant.new(label.to_s) }
end

# Binding strength of each binary operator; a larger number binds tighter.
OP_PRECEDENCE = {
  10 => ['+', '-'],
  20 => ['*', '/'],
  30 => ['^', '**'],
}.each_with_object({}) do |(level, operators), table|
  operators.each { |operator| table[operator] = level }
end

# Applies the binary operator named by +operator+ to +left+ and +right+.
# '^' and '**' both mean exponentiation. Returns nil for any other operator
# string.
def BinaryOperation(left, operator, right)
  _symbol, method_name = [
    ['+', :+],
    ['-', :-],
    ['*', :*],
    ['/', :/],
    ['^', :**],
    ['**', :**],
  ].find { |symbol, _method| symbol === operator }

  method_name && left.public_send(method_name, right)
end

# Evaluates a flat expression list by precedence climbing.
#
# expression       - Array of Hashes. Every item holds either :amount (a
#                    value) or :expression (a nested list); every item after
#                    the first also holds the :operator that precedes it.
#                    A bare Hash is accepted as a one-item list, because a
#                    single-factor expression such as "5" or "(5)" is parsed
#                    to a Hash rather than an Array.
# precedence_limit - operators whose precedence is not above this limit are
#                    left for the caller.
#
# Returns [value, remaining_items], where remaining_items is nil once
# everything has been consumed. An empty list yields [nil, nil].
# Operators of equal precedence are applied left to right.
def construct_ast_recursive(expression, precedence_limit)
  expression = [expression] if expression.is_a?(Hash)

  return nil, nil if expression.length == 0

  first = expression[0]
  raise "at least two factors required" unless first

  if first.has_key?(:amount)
    lhs = first[:amount]
  elsif first.has_key?(:expression)
    lhs = construct_ast(first[:expression])
  else
    raise "no expression or amount in expression item"
  end

  rest = expression[1..-1]
  rest = nil if rest.length == 0

  while rest
    op = rest[0][:operator].to_s
    precedence = OP_PRECEDENCE[op]
    # A weaker or equal operator belongs to an outer call.
    return lhs, rest if precedence <= precedence_limit

    # The item that carries the operator also carries the right operand, so
    # the recursive call starts on that same item.
    rhs, rest = construct_ast_recursive(rest, precedence)
    lhs = BinaryOperation(lhs, op, rhs)
  end

  return lhs, nil
end

# Evaluates a whole expression list and returns only the resulting value.
# The limit of -1 lets every operator be consumed.
def construct_ast(expression)
  value, _remaining = construct_ast_recursive(expression, -1)
  value
end

# Parses the expression text +str+, transforms the parse tree and returns the
# evaluated result.
def parse(str)
  tree = Expression.new.parse(str)
  construct_ast(Transform_Expression.new.apply(tree))
end

# Runs the Query parser over the input line +str+ and returns its parse
# result.
def parse_query(str)
  Query.new.parse(str)
end

# Renders a flat [name, exponent, ...] unit array as text such as "[m s-1]".
# Units with exponent 0 are omitted, an exponent of 1 is not written, and
# whole-number exponents are printed without a fractional part.
def format_unit(ary)
  terms = ary.each_slice(2).reject { |_name, exponent| exponent == 0 }.map do |name, exponent|
    if exponent == 1
      name
    elsif exponent.to_i == exponent
      name + exponent.to_i.to_s
    else
      name + exponent.to_s
    end
  end
  "[" + terms.join(" ") + "]"
end

# Prints the help text (currently only the program name) to standard output.
def show_help
  puts "calculator"
end

# Interactive prompt, started only when this file is run as a script.
#
#   :h          prints the help text
#   :q / EOF    leaves the prompt
#   EXPR        evaluates EXPR and prints it in SI base units
#   EXPR=?UNIT  evaluates EXPR and prints it converted to UNIT
if $0 == __FILE__

  require "readline"

  # Prints +message+ below the red error banner.
  def report_error(message)
    puts "\e[1;31m################ ERROR ################"
    puts message
  end

  # Prints "+value+ +label+" in green, then resets the terminal colour.
  def report_result(value, label)
    print "\e[1;32m"
    print value, " ", label, "\n"
    print "\e[0m"
  end

  # Evaluates one input line and prints either the result or a
  # dimension-mismatch error. Parse and arithmetic errors are raised.
  def evaluate_line(line)
    parsed_query = parse_query(line)
    # "1*" guarantees at least two factors, so the parser yields a list.
    lhs = parse("1*" + parsed_query[:lhs].to_s)

    unless parsed_query[:op].to_s.include?("=?")
      report_result(lhs.val, format_unit(lhs.unit))
      return
    end

    # A target written with brackets is used verbatim; a bare unit
    # expression is wrapped in brackets first.
    label = if parsed_query[:rhs]
              parsed_query[:rhs].to_s
            else
              "[" + parsed_query[:rhs_nb].to_s + "]"
            end
    rhs = parse("1*1" + label)

    if lhs.unit != rhs.unit
      report_error("dimension mismatch: lhs #{format_unit(lhs.unit)} rhs #{format_unit(rhs.unit)}")
      return
    end

    report_result((lhs / rhs).val, label)
  end

  loop do
    str = Readline.readline("\e[0m> ", true)
    # Readline returns nil at end of input (e.g. Ctrl-D).
    break if str.nil? || str == ":q"

    if str == ":h"
      show_help
      next
    end
    next if str.strip.empty?

    begin
      evaluate_line(str)
    rescue StandardError => e
      # Keep the session alive on syntax errors and incompatible dimensions.
      report_error(e.message)
    end
  end

end